diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 1b1478ae6..86c1b326f 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -324,3 +324,30 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names # SURFSENSE_ALLOWED_PLUGINS=year_substituter + +# ----------------------------------------------------------------------------- +# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON) +# ----------------------------------------------------------------------------- +# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU +# on a cold turn) is reused across subsequent turns on the same thread, +# collapsing it to a microsecond hash lookup. All connector tools acquire +# their own short-lived DB session per call (Phase 2 refactor) so a cached +# closure is safe to share across requests. Flip OFF only as a last-resort +# rollback if you suspect cache-related staleness. +# SURFSENSE_ENABLE_AGENT_CACHE=true + +# Cache capacity (max number of compiled-agent entries kept in memory) +# and TTL per entry (seconds). Working set is typically one entry per +# active thread on this replica; tune up for very large deployments. +# SURFSENSE_AGENT_CACHE_MAXSIZE=256 +# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800 + +# ----------------------------------------------------------------------------- +# Connector discovery TTL cache (Phase 1.4 perf optimization) +# ----------------------------------------------------------------------------- +# Caches the per-search-space "available connectors" + "available document +# types" lookups that ``create_surfsense_deep_agent`` hits on every turn. +# ORM event listeners auto-invalidate on connector / document inserts, +# updates and deletes — the TTL only bounds staleness for bulk-import +# paths that bypass the ORM. Set to 0 to disable the cache. +# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30 diff --git a/surfsense_backend/app/agents/new_chat/agent_cache.py b/surfsense_backend/app/agents/new_chat/agent_cache.py new file mode 100644 index 000000000..fa8e6fb72 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/agent_cache.py @@ -0,0 +1,357 @@ +"""TTL-LRU cache for compiled SurfSense deep agents. + +Why this exists +--------------- + +``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat +turn: + +1. Discover connectors & document types from Postgres (~50-200ms) +2. Build the tool list (built-in + MCP) (~200ms-1.7s) +3. Compose the system prompt +4. Construct ~15 middleware instances (CPU) +5. Eagerly compile the general-purpose subagent + (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously, + which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure + CPU work) +6. Compile the outer LangGraph + +For a single thread, all six steps produce the SAME object on every turn +unless the user has changed their LLM config, toggled a feature flag, +added a connector, etc. The right answer is to compile ONCE per +"agent shape" and reuse the resulting :class:`CompiledStateGraph` for +every subsequent turn on the same thread. + +Why a per-thread key (not a global pool) +---------------------------------------- + +Most middleware in the SurfSense stack captures per-thread state in +``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``, +``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse +would silently leak state across users and threads. 
Keying the cache on +``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated +turns on the same thread without changing any middleware's behavior. + +Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema` +(read via ``runtime.context``) so the cache can collapse to a single +``(llm_config_id, search_space_id, ...)`` key shared across threads. Until +then, per-thread keying is the only safe option. + +Cache shape +----------- + +* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30 + minutes — matches a typical chat session). ``maxsize`` (default 256) + caps memory; LRU evicts least-recently-used on overflow. +* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent + cold misses on the same key wait for the first build instead of + building N times. +* Process-local: this is an in-memory cache. Multi-replica deployments + pay the build cost once per replica per key. That's fine; the working + set per replica is small (one entry per active thread on that replica). + +Telemetry +--------- + +Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``: + + * ``hit`` — cache hit, microseconds-fast + * ``miss`` — first build for this key, includes build duration + * ``stale`` — entry was found but expired; rebuilt + * ``evict`` — LRU eviction (size-limited) + * ``size`` — current cache occupancy at lookup time +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +# --------------------------------------------------------------------------- +# Public API: signature helpers (cache key components) +# --------------------------------------------------------------------------- + + +def stable_hash(*parts: Any) -> str: + """Compute a deterministic SHA1 of the str repr of ``parts``. + + Used for cache key components that need a fixed-width representation + (system prompt, tool list, etc.). SHA1 is fine here — this is not a + security boundary, just a content fingerprint. + """ + h = hashlib.sha1(usedforsecurity=False) + for p in parts: + h.update(repr(p).encode("utf-8", errors="replace")) + h.update(b"\x1f") # ASCII unit separator between parts + return h.hexdigest() + + +def tools_signature( + tools: list[Any] | tuple[Any, ...], + *, + available_connectors: list[str] | None, + available_document_types: list[str] | None, +) -> str: + """Hash the bound-tool surface for cache-key purposes. + + The signature changes whenever: + + * A tool is added or removed from the bound list (built-in toggles, + MCP tools loaded for the user changes, gating rules flip, etc.). + * The available connectors / document types for the search space + change (new connector added, last connector removed, new document + type indexed). Because :func:`get_connector_gated_tools` derives + ``modified_disabled_tools`` from ``available_connectors``, the + tool surface is technically already covered — but we hash the + connector list separately so an empty-list "no tools changed" + situation still rotates the key when, say, the user re-adds a + connector that gates a tool we were already not exposing. + + Stays stable across: + + * Process restarts (tool names + descriptions are static). 
+ * Different replicas (everyone gets the same hash for the same + inputs). + """ + tool_descriptors = sorted( + (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools + ) + connectors = sorted(available_connectors or []) + doc_types = sorted(available_document_types or []) + return stable_hash(tool_descriptors, connectors, doc_types) + + +def flags_signature(flags: Any) -> str: + """Hash the resolved :class:`AgentFeatureFlags` dataclass. + + Frozen dataclasses are deterministically reprable, so a SHA1 of their + repr is a stable fingerprint. Restart safe (flags are read once at + process boot). + """ + return stable_hash(repr(flags)) + + +def system_prompt_hash(system_prompt: str) -> str: + """Hash a system prompt string. Cheap, ~30µs for typical prompts.""" + return hashlib.sha1( + system_prompt.encode("utf-8", errors="replace"), + usedforsecurity=False, + ).hexdigest() + + +# --------------------------------------------------------------------------- +# Cache implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _Entry: + value: Any + created_at: float + last_used_at: float + + +class _AgentCache: + """In-process TTL-LRU cache with per-key in-flight de-duplication. + + NOT THREAD-SAFE in the multithreading sense — designed for a single + asyncio event loop. Uvicorn runs one event loop per worker process, + so this is fine; multi-worker deployments simply each maintain their + own cache. + """ + + def __init__(self, *, maxsize: int, ttl_seconds: float) -> None: + self._maxsize = maxsize + self._ttl = ttl_seconds + self._entries: OrderedDict[str, _Entry] = OrderedDict() + # One lock per key — guards "build" so concurrent cold misses on + # the same key wait for the first build instead of all racing. + self._locks: dict[str, asyncio.Lock] = {} + + def _now(self) -> float: + return time.monotonic() + + def _is_fresh(self, entry: _Entry) -> bool: + return (self._now() - entry.created_at) < self._ttl + + def _evict_if_full(self) -> None: + while len(self._entries) >= self._maxsize: + evicted_key, _ = self._entries.popitem(last=False) + self._locks.pop(evicted_key, None) + _perf_log.info( + "[agent_cache] evict key=%s reason=lru size=%d", + _short(evicted_key), + len(self._entries), + ) + + def _touch(self, key: str, entry: _Entry) -> None: + entry.last_used_at = self._now() + self._entries.move_to_end(key, last=True) + + async def get_or_build( + self, + key: str, + *, + builder: Callable[[], Awaitable[Any]], + ) -> Any: + """Return the cached value for ``key`` or call ``builder()`` to make it. + + ``builder`` MUST be idempotent — concurrent cold misses on the + same key collapse to a single ``builder()`` call (the others + wait on the in-flight lock and observe the populated entry on + wake). + """ + # Fast path: hot hit. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + # Stale entry — drop it; rebuild below. + if entry is not None and not self._is_fresh(entry): + _perf_log.info( + "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs", + _short(key), + self._now() - entry.created_at, + self._ttl, + ) + self._entries.pop(key, None) + + # Slow path: serialize concurrent misses for the same key. 
+ lock = self._locks.setdefault(key, asyncio.Lock()) + async with lock: + # Double-check after acquiring the lock — another waiter may + # have populated the entry while we slept. + entry = self._entries.get(key) + if entry is not None and self._is_fresh(entry): + self._touch(key, entry) + _perf_log.info( + "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true", + _short(key), + self._now() - entry.created_at, + len(self._entries), + ) + return entry.value + + t0 = time.perf_counter() + try: + value = await builder() + except BaseException: + # Don't cache failed builds; let the next caller retry. + _perf_log.warning( + "[agent_cache] build_failed key=%s elapsed=%.3fs", + _short(key), + time.perf_counter() - t0, + ) + raise + elapsed = time.perf_counter() - t0 + + # Insert + evict. + self._evict_if_full() + now = self._now() + self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now) + self._entries.move_to_end(key, last=True) + _perf_log.info( + "[agent_cache] miss key=%s build=%.3fs size=%d", + _short(key), + elapsed, + len(self._entries), + ) + return value + + def invalidate(self, key: str) -> bool: + """Drop a single entry; return True if anything was removed.""" + removed = self._entries.pop(key, None) is not None + self._locks.pop(key, None) + if removed: + _perf_log.info( + "[agent_cache] invalidate key=%s size=%d", + _short(key), + len(self._entries), + ) + return removed + + def invalidate_prefix(self, prefix: str) -> int: + """Drop every entry whose key starts with ``prefix``. Returns count.""" + keys = [k for k in self._entries if k.startswith(prefix)] + for k in keys: + self._entries.pop(k, None) + self._locks.pop(k, None) + if keys: + _perf_log.info( + "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d", + _short(prefix), + len(keys), + len(self._entries), + ) + return len(keys) + + def clear(self) -> None: + n = len(self._entries) + self._entries.clear() + self._locks.clear() + if n: + _perf_log.info("[agent_cache] clear removed=%d", n) + + def stats(self) -> dict[str, Any]: + return { + "size": len(self._entries), + "maxsize": self._maxsize, + "ttl_seconds": self._ttl, + } + + +def _short(key: str, n: int = 16) -> str: + """Truncate keys for log lines so they don't blow up log volume.""" + return key if len(key) <= n else f"{key[:n]}..." + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + +_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256")) +_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800")) + +_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL) + + +def get_cache() -> _AgentCache: + """Return the process-wide compiled-agent cache singleton.""" + return _cache + + +def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache: + """Replace the singleton with a fresh cache. 
Tests only.""" + global _cache + _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds) + return _cache + + +__all__ = [ + "flags_signature", + "get_cache", + "reload_for_tests", + "stable_hash", + "system_prompt_hash", + "tools_signature", +] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index c0e9a3b96..36739adae 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.agent_cache import ( + flags_signature, + get_cache, + stable_hash, + system_prompt_hash, + tools_signature, +) from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver @@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import ( DedupHITLToolCallsMiddleware, DoomLoopMiddleware, FileIntentMiddleware, + FlattenSystemMessageMiddleware, KnowledgeBasePersistenceMiddleware, KnowledgePriorityMiddleware, KnowledgeTreeMiddleware, @@ -330,23 +338,39 @@ async def create_surfsense_deep_agent( else None, ) - # Discover available connectors and document types for this search space + # Discover available connectors and document types for this search space. + # + # NOTE: These two calls cannot be parallelized via ``asyncio.gather``. + # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``); + # SQLAlchemy explicitly forbids concurrent operations on the same session + # ("This session is provisioning a new connection; concurrent operations + # are not permitted on the same session"). The Phase 1.4 in-process TTL + # cache in ``connector_service`` already collapses the warm path to a + # near-zero pair of dict lookups, so sequential awaits cost nothing in + # the common case while remaining correct on cold cache misses. available_connectors: list[str] | None = None available_document_types: list[str] | None = None _t0 = time.perf_counter() try: - connector_types = await connector_service.get_available_connectors( - search_space_id - ) - if connector_types: - available_connectors = _map_connectors_to_searchable_types(connector_types) + try: + connector_types_result = await connector_service.get_available_connectors( + search_space_id + ) + if connector_types_result: + available_connectors = _map_connectors_to_searchable_types( + connector_types_result + ) + except Exception as e: + logging.warning("Failed to discover available connectors: %s", e) - available_document_types = await connector_service.get_available_document_types( - search_space_id - ) - - except Exception as e: + try: + available_document_types = ( + await connector_service.get_available_document_types(search_space_id) + ) + except Exception as e: + logging.warning("Failed to discover available document types: %s", e) + except Exception as e: # pragma: no cover - defensive outer guard logging.warning(f"Failed to discover available connectors/document types: {e}") _perf_log.info( "[create_agent] Connector/doc-type discovery in %.3fs", @@ -469,29 +493,77 @@ async def create_surfsense_deep_agent( # entire middleware build + main-graph compile into a single # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # event loop stays responsive. 
+ # + # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed + # on every per-request value that any middleware in the stack closes + # over in ``__init__`` — drop one and you risk leaking state across + # threads. Hits collapse this whole block to a microsecond lookup; + # misses pay the original CPU cost AND populate the cache. + config_id = agent_config.config_id if agent_config is not None else None + + async def _build_agent() -> Any: + return await asyncio.to_thread( + _build_compiled_agent_blocking, + llm=llm, + tools=tools, + final_system_prompt=final_system_prompt, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_selection.mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + visibility=visibility, + anon_session_id=anon_session_id, + available_connectors=available_connectors, + available_document_types=available_document_types, + # ``mentioned_document_ids`` is consumed by + # ``KnowledgePriorityMiddleware`` per turn via + # ``runtime.context`` (Phase 1.5). We still pass the + # caller-provided list here for the legacy fallback path + # (cache disabled / context not propagated) — the middleware + # drains its own copy after the first read so a cached graph + # never replays stale mentions. + mentioned_document_ids=mentioned_document_ids, + max_input_tokens=_max_input_tokens, + flags=_flags, + checkpointer=checkpointer, + ) + _t0 = time.perf_counter() - agent = await asyncio.to_thread( - _build_compiled_agent_blocking, - llm=llm, - tools=tools, - final_system_prompt=final_system_prompt, - backend_resolver=backend_resolver, - filesystem_mode=filesystem_selection.mode, - search_space_id=search_space_id, - user_id=user_id, - thread_id=thread_id, - visibility=visibility, - anon_session_id=anon_session_id, - available_connectors=available_connectors, - available_document_types=available_document_types, - mentioned_document_ids=mentioned_document_ids, - max_input_tokens=_max_input_tokens, - flags=_flags, - checkpointer=checkpointer, - ) + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack: + # Cache key components — order matters only for human readability; + # the resulting hash is what's stored. Every component must + # rotate on a real shape change AND stay stable across identical + # invocations. + cache_key = stable_hash( + "v1", # schema version of the key — bump if components change + config_id, + thread_id, + user_id, + search_space_id, + visibility, + filesystem_selection.mode, + anon_session_id, + tools_signature( + tools, + available_connectors=available_connectors, + available_document_types=available_document_types, + ), + flags_signature(_flags), + system_prompt_hash(final_system_prompt), + _max_input_tokens, + # ``mentioned_document_ids`` deliberately omitted — middleware + # reads it from ``runtime.context`` (Phase 1.5). + ) + agent = await get_cache().get_or_build(cache_key, builder=_build_agent) + else: + agent = await _build_agent() _perf_log.info( - "[create_agent] Middleware stack + graph compiled in %.3fs", + "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)", time.perf_counter() - _t0, + "on" + if _flags.enable_agent_cache and not _flags.disable_new_agent_stack + else "off", ) _perf_log.info( @@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking( noop_mw, retry_mw, fallback_mw, + # Coalesce a multi-text-block system message into one block + # immediately before the model call. 
Sits innermost on the + # system-message-mutation chain so it observes every appender + # (todo / filesystem / skills / subagents …) and prevents + # OpenRouter→Anthropic from redistributing ``cache_control`` + # across N blocks and tripping Anthropic's 4-breakpoint cap. + # See ``middleware/flatten_system.py`` for full rationale. + FlattenSystemMessageMiddleware(), # Tool-call repair must run after model emits but before # permission / dedup / doom-loop interpret the calls. repair_mw, diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index c1fe45aaa..d720b524b 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -1,10 +1,25 @@ """ Context schema definitions for SurfSense agents. -This module defines the custom state schema used by the SurfSense deep agent. +This module defines the per-invocation context object passed to the SurfSense +deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6). + +The agent's compiled graph is the same across invocations (and cached by +``agent_cache``), so anything that varies per turn — the user mentions a +specific document, the front-end issues a unique ``request_id``, etc. — +MUST live on this context object instead of being captured into a +middleware ``__init__`` closure. Middlewares read fields back via +``runtime.context.``; tools read them via ``runtime.context``. + +This object is read inside both ``KnowledgePriorityMiddleware`` (for +``mentioned_document_ids``) and any future middleware that needs +per-request state without invalidating the compiled-agent cache. """ -from typing import NotRequired, TypedDict +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TypedDict class FileOperationContractState(TypedDict): @@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict): turn_id: str -class SurfSenseContextSchema(TypedDict): +@dataclass +class SurfSenseContextSchema: """ - Custom state schema for the SurfSense deep agent. + Per-invocation context for the SurfSense deep agent. - This extends the default agent state with custom fields. - The default state already includes: - - messages: Conversation history - - todos: Task list from TodoListMiddleware - - files: Virtual filesystem from FilesystemMiddleware + Defaults are chosen so the dataclass can be safely default-constructed + (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no + context is supplied — see ``langgraph.runtime.Runtime``). All fields + are optional; consumers must None-check before reading. - We're adding fields needed for knowledge base search: - - search_space_id: The user's search space ID - - db_session: Database session (injected at runtime) - - connector_service: Connector service instance (injected at runtime) + Phase 1.5 fields: + search_space_id: Search space the request is scoped to. + mentioned_document_ids: KB documents the user @-mentioned this turn. + Read by ``KnowledgePriorityMiddleware`` to seed its priority + list. Stays out of the compiled-agent cache key — that's the + whole point of putting it here. + file_operation_contract: One-shot file operation contract emitted + by ``FileIntentMiddleware`` for the upcoming turn. + turn_id / request_id: Correlation IDs surfaced by the streaming + task; populated for telemetry. 
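    A minimal caller sketch — the literal IDs and the user message are
    invented for illustration, and ``agent`` is assumed to be the compiled
    deep agent returned by ``create_surfsense_deep_agent``; the call shape
    follows the ``astream_events(..., context=ctx)`` contract described in
    the module docstring::

        ctx = SurfSenseContextSchema(
            search_space_id=7,
            mentioned_document_ids=[101, 102],  # this turn's @-mentions only
            request_id="req-abc123",
        )
        # The compiled (and cached) graph is reused across turns; only the
        # context object varies per invocation.
        async for event in agent.astream_events(
            {"messages": [("user", "Summarize the two docs I mentioned")]},
            context=ctx,
            version="v2",
        ):
            ...
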
+ + Phase 2 will extend with: thread_id, user_id, visibility, + filesystem_mode, anon_session_id, available_connectors, + available_document_types, created_by_id (everything currently captured + by middleware ``__init__`` closures). """ - search_space_id: int - file_operation_contract: NotRequired[FileOperationContractState] - turn_id: NotRequired[str] - request_id: NotRequired[str] - # These are runtime-injected and won't be serialized - # db_session and connector_service are passed when invoking the agent + search_space_id: int | None = None + mentioned_document_ids: list[int] = field(default_factory=list) + file_operation_contract: FileOperationContractState | None = None + turn_id: str | None = None + request_id: str | None = None diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 5007d89a5..1f5a08ec6 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -103,6 +103,41 @@ class AgentFeatureFlags: # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) enable_otel: bool = False + # Performance — compiled-agent cache (Phase 1 + Phase 2). + # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled + # graph if the cache key matches (LLM config + thread + tool surface + + # flags + system prompt + filesystem mode). Cuts per-turn agent-build + # wall clock from ~4-5s to <50µs on cache hits. + # + # SAFETY (Phase 2 unblocked this default-on): + # All connector mutation tools (``tools/notion``, ``tools/gmail``, + # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, + # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, + # ``tools/teams``, ``tools/luma``, ``connected_accounts``, + # ``update_memory``, ``search_surfsense_docs``) now acquire fresh + # short-lived ``AsyncSession`` instances per call via + # :data:`async_session_maker`. The factory still accepts ``db_session`` + # for registry compatibility but ``del``'s it immediately — see any + # of those files' factory docstrings for the rationale. The ``llm`` + # closure is per-(provider, model, config_id) which is already in + # the cache key, so the LLM is safe to share across cached hits of + # the same key. The KB priority middleware reads + # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5), + # not its constructor closure, so the same compiled agent serves + # turns with different mention lists correctly. + # + # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the + # environment if a regression surfaces. The path is exercised by + # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite. + enable_agent_cache: bool = True + # Phase 1 (deferred — measure first): pre-build & share the + # general-purpose subagent ``CompiledSubAgent`` across cold-cache + # misses. Only helps when the outer cache MISSES (cache hits already + # reuse the entire SubAgentMiddleware-compiled graph). Off by default + # until we have data showing cold misses are frequent enough to + # justify the extra global state. + enable_agent_cache_share_gp_subagent: bool = False + @classmethod def from_env(cls) -> AgentFeatureFlags: """Read flags from environment. 
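The ``enable_agent_cache`` docstring above lists what feeds the cache key; a
minimal sketch of how the key behaves, using only the helpers added in
``agent_cache.py`` earlier in this diff (the literal IDs and prompt strings
are invented for illustration)::

    from app.agents.new_chat.agent_cache import stable_hash, system_prompt_hash

    prompt_v1 = "You are SurfSense ..."
    prompt_v2 = prompt_v1 + "\nAlways cite sources."

    key_a = stable_hash("v1", 7, "thread-42", system_prompt_hash(prompt_v1))
    key_b = stable_hash("v1", 7, "thread-42", system_prompt_hash(prompt_v1))
    key_c = stable_hash("v1", 7, "thread-42", system_prompt_hash(prompt_v2))

    assert key_a == key_b  # identical agent shape -> same key -> cache hit
    assert key_a != key_c  # prompt change rotates the key -> rebuild on next turn
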
@@ -137,6 +172,8 @@ class AgentFeatureFlags: enable_stream_parity_v2=False, enable_plugin_loader=False, enable_otel=False, + enable_agent_cache=False, + enable_agent_cache_share_gp_subagent=False, ) return cls( @@ -179,6 +216,11 @@ class AgentFeatureFlags: enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), + # Performance + enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True), + enable_agent_cache_share_gp_subagent=_env_bool( + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False + ), ) def any_new_middleware_enabled(self) -> bool: diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 094c102f8..6742bd8de 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import ( from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, +) from app.agents.new_chat.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, commit_staged_filesystem_state, @@ -61,6 +64,7 @@ __all__ = [ "DedupHITLToolCallsMiddleware", "DoomLoopMiddleware", "FileIntentMiddleware", + "FlattenSystemMessageMiddleware", "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", diff --git a/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py new file mode 100644 index 000000000..29cd57aa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/flatten_system.py @@ -0,0 +1,233 @@ +r"""Coalesce multi-block system messages into a single text block. + +Several middlewares in our deepagent stack each call +``append_to_system_message`` on the way down to the model +(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``, +``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the +request reaches the LLM, the system message has 5+ separate text blocks. + +Anthropic enforces a hard cap of **4 ``cache_control`` blocks per +request**, and we configure 2 injection points +(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting +the prepended ``request.system_message``, this middleware is the +defensive partner: it guarantees that "the system block" is *one* +content block, so LiteLLM's ``AnthropicCacheControlHook`` and any +OpenRouter→Anthropic transformer can never multiply our budget into +several breakpoints by spreading ``cache_control`` across multiple +text blocks of a multi-block system content. + +Without flattening we used to see:: + + OpenrouterException - {"error":{"message":"Provider returned error", + "code":400,"metadata":{"raw":"...A maximum of 4 blocks with + cache_control may be provided. Found 5."}}} + +(Same error class documented in +https://github.com/BerriAI/litellm/issues/15696 and +https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix +in PR #15395 covers the litellm transformer but does not protect us +when the OpenRouter SaaS itself does the redistribution.) 
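A before/after sketch of the flattening itself, using the
``_flatten_text_blocks`` helper defined below (block contents abbreviated)::

    blocks = [
        {"type": "text", "text": "Base SurfSense system prompt ..."},
        {"type": "text", "text": "## Todo list ..."},
        {"type": "text", "text": "## Filesystem ..."},
    ]
    # One "\n\n"-joined block means the index-0 injection point can only
    # ever place a single cache_control breakpoint on the system message.
    assert _flatten_text_blocks(blocks) == (
        "Base SurfSense system prompt ...\n\n"
        "## Todo list ...\n\n"
        "## Filesystem ..."
    )
    # Anything non-text makes the helper bail out with None; the request is
    # then passed through unchanged rather than losing the payload.
    assert _flatten_text_blocks([{"type": "image", "data": "..."}]) is None
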
+ +A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching +the first injection point from ``role: system`` to ``index: 0``) +neutralises the *primary* cause of the same 400 — multiple +``SystemMessage``\ s injected by ``before_agent`` middlewares +(priority/tree/memory/file-intent/anonymous-doc) accumulating across +turns, each tagged with ``cache_control`` by the ``role: system`` +matcher. This middleware remains useful as defence-in-depth against +the multi-block redistribution path. + +Placement: innermost on the system-message-mutation chain, after every +appender (``todo``/``filesystem``/``skills``/``subagents``) and after +summarization, but before ``noop``/``retry``/``fallback`` so each retry +attempt sees a flattened payload. See ``chat_deepagent.py``. + +Idempotent: a string-content system message is left untouched. A list +that contains anything other than plain text blocks (e.g. an image) is +also left untouched — those are rare on system messages and we'd lose +the non-text payload by joining. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import SystemMessage + +logger = logging.getLogger(__name__) + + +def _flatten_text_blocks(content: list[Any]) -> str | None: + """Return joined text if every block is a plain ``{"type": "text"}``. + + Returns ``None`` when the list contains anything that isn't a text + block we can safely concatenate (image, audio, file, non-standard + blocks, dicts with extra non-cache_control fields). The caller + leaves the original content untouched in that case rather than + silently dropping payload. + + ``cache_control`` on individual blocks is intentionally discarded — + the whole point of flattening is to let LiteLLM's + ``cache_control_injection_points`` re-place a single breakpoint on + the resulting one-block system content. + """ + chunks: list[str] = [] + for block in content: + if isinstance(block, str): + chunks.append(block) + continue + if not isinstance(block, dict): + return None + if block.get("type") != "text": + return None + text = block.get("text") + if not isinstance(text, str): + return None + chunks.append(text) + return "\n\n".join(chunks) + + +def _flattened_request( + request: ModelRequest[ContextT], +) -> ModelRequest[ContextT] | None: + """Return a request with system_message flattened, or ``None`` for no-op.""" + sys_msg = request.system_message + if sys_msg is None: + return None + content = sys_msg.content + if not isinstance(content, list) or len(content) <= 1: + return None + + flattened = _flatten_text_blocks(content) + if flattened is None: + return None + + new_sys = SystemMessage( + content=flattened, + additional_kwargs=dict(sys_msg.additional_kwargs), + response_metadata=dict(sys_msg.response_metadata), + ) + if sys_msg.id is not None: + new_sys.id = sys_msg.id + return request.override(system_message=new_sys) + + +def _diagnostic_summary(request: ModelRequest[Any]) -> str: + """One-line dump of cache_control-relevant request shape. + + Temporary diagnostic to prove where the ``Found N`` cache_control + breakpoints are coming from when Anthropic 400s. Removed once the + root cause is confirmed and a fix is in place. 
+ """ + sys_msg = request.system_message + if sys_msg is None: + sys_shape = "none" + elif isinstance(sys_msg.content, str): + sys_shape = f"str(len={len(sys_msg.content)})" + elif isinstance(sys_msg.content, list): + sys_shape = f"list(blocks={len(sys_msg.content)})" + else: + sys_shape = f"other({type(sys_msg.content).__name__})" + + role_hist: list[str] = [] + multi_block_msgs = 0 + msgs_with_cc = 0 + sys_msgs_in_history = 0 + for m in request.messages: + mtype = getattr(m, "type", type(m).__name__) + role_hist.append(mtype) + if isinstance(m, SystemMessage): + sys_msgs_in_history += 1 + c = getattr(m, "content", None) + if isinstance(c, list): + multi_block_msgs += 1 + for blk in c: + if isinstance(blk, dict) and "cache_control" in blk: + msgs_with_cc += 1 + break + if "cache_control" in getattr(m, "additional_kwargs", {}) or {}: + msgs_with_cc += 1 + + tools = request.tools or [] + tools_with_cc = 0 + for t in tools: + if isinstance(t, dict) and ( + "cache_control" in t or "cache_control" in t.get("function", {}) + ): + tools_with_cc += 1 + + return ( + f"sys={sys_shape} msgs={len(request.messages)} " + f"sys_msgs_in_history={sys_msgs_in_history} " + f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} " + f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} " + f"roles={role_hist[-8:]}" + ) + + +class FlattenSystemMessageMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): + """Collapse a multi-text-block system message to a single string. + + Sits innermost on the system-message-mutation chain so it observes + every middleware's contribution. Has no other side effect — the + body of every block is preserved, just joined with ``"\\n\\n"``. + """ + + def __init__(self) -> None: + super().__init__() + self.tools = [] + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> Any: + if logger.isEnabledFor(logging.DEBUG): + logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) + flattened = _flattened_request(request) + if flattened is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[flatten_system] collapsed %d system blocks to one", + len(request.system_message.content), # type: ignore[arg-type, union-attr] + ) + return handler(flattened) + return handler(request) + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> Any: + if logger.isEnabledFor(logging.DEBUG): + logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) + flattened = _flattened_request(request) + if flattened is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[flatten_system] collapsed %d system blocks to one", + len(request.system_message.content), # type: ignore[arg-type, union-attr] + ) + return await handler(flattened) + return await handler(request) + + +__all__ = [ + "FlattenSystemMessageMiddleware", + "_flatten_text_blocks", + "_flattened_request", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 0820e8c3e..ee5c1d182 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -732,7 +732,6 @@ class 
KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: - del runtime if self.filesystem_mode != FilesystemMode.CLOUD: return None @@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] if anon_doc: return self._anon_priority(state, anon_doc) - return await self._authenticated_priority(state, messages, user_text) + return await self._authenticated_priority(state, messages, user_text, runtime) def _anon_priority( self, @@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] state: AgentState, messages: Sequence[BaseMessage], user_text: str, + runtime: Runtime[Any] | None = None, ) -> dict[str, Any]: t0 = asyncio.get_event_loop().time() ( @@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) + # Per-turn ``mentioned_document_ids`` flow: + # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the + # streaming task supplies a fresh :class:`SurfSenseContextSchema` + # on every ``astream_events`` call, so this list is naturally + # scoped to the current turn. Allows cross-turn graph reuse via + # ``agent_cache``. + # 2. Legacy fallback (cache disabled / context not propagated): the + # constructor-injected ``self.mentioned_document_ids`` list. We + # drain it after the first read so a cached graph (no Phase 1.5 + # wiring) doesn't keep replaying the same mentions on every + # turn. + # + # CRITICAL: distinguish "context absent" (legacy caller, no field at + # all) from "context provided but empty" (turn with no mentions). + # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in + # Python, so a naive ``if ctx_mentions:`` would fall through to the + # legacy closure on every no-mention follow-up turn — replaying the + # mentions baked in by turn 1's cache-miss build. Always drain the + # closure once the runtime path has fired so a cached middleware + # instance can never resurrect stale state. + mention_ids: list[int] = [] + ctx = getattr(runtime, "context", None) if runtime is not None else None + ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None + if ctx_mentions is not None: + # Runtime path is authoritative — even an empty list means + # "this turn has no mentions", NOT "look at the closure". + mention_ids = list(ctx_mentions) + if self.mentioned_document_ids: + self.mentioned_document_ids = [] + elif self.mentioned_document_ids: + mention_ids = list(self.mentioned_document_ids) + self.mentioned_document_ids = [] + mentioned_results: list[dict[str, Any]] = [] - if self.mentioned_document_ids: + if mention_ids: mentioned_results = await fetch_mentioned_documents( - document_ids=self.mentioned_document_ids, + document_ids=mention_ids, search_space_id=self.search_space_id, ) - self.mentioned_document_ids = [] if is_recency: doc_types = _resolve_search_types( diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py index 86bc57725..9fe47cdac 100644 --- a/surfsense_backend/app/agents/new_chat/prompt_caching.py +++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py @@ -1,4 +1,4 @@ -"""LiteLLM-native prompt caching configuration for SurfSense agents. +r"""LiteLLM-native prompt caching configuration for SurfSense agents. 
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` @@ -17,8 +17,20 @@ Coverage: We inject **two** breakpoints per request: -- ``role: system`` — pins the SurfSense system prompt (provider variant, - citation rules, tool catalog, KB tree, skills metadata) into the cache. +- ``index: 0`` — pins the SurfSense system prompt at the head of the + request (provider variant, citation rules, tool catalog, KB tree, + skills metadata). The langchain agent factory always prepends + ``request.system_message`` at index 0 (see ``factory.py`` + ``_execute_model_async``), so this targets exactly the main system + prompt regardless of how many other ``SystemMessage``\ s the + ``before_agent`` injectors (priority, tree, memory, file-intent, + anonymous-doc) have inserted into ``state["messages"]``. Using + ``role: system`` here would apply ``cache_control`` to **every** + system-role message and trip Anthropic's hard cap of 4 cache + breakpoints per request once the conversation accumulates enough + injected system messages — which surfaces as the upstream 400 + ``A maximum of 4 blocks with cache_control may be provided. Found N`` + via OpenRouter→Anthropic. - ``index: -1`` — pins the latest message so multi-turn savings compound: Anthropic-family providers use longest-matching-prefix lookup, so turn N+1 still reads turn N's cache up to the shared prefix. @@ -51,11 +63,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Two-breakpoint policy: system + latest message. See module docstring for -# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we -# use 2 here, leaving headroom for Phase-2 tool caching. +# Two-breakpoint policy: head-of-request + latest message. See module +# docstring for rationale. Anthropic caps requests at 4 ``cache_control`` +# blocks; we use 2 here, leaving headroom for Phase-2 tool caching. +# +# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's +# ``before_agent`` middlewares (priority, tree, memory, file-intent, +# anonymous-doc) insert ``SystemMessage`` instances into +# ``state["messages"]`` that accumulate across turns. With +# ``role: system`` the LiteLLM hook would tag *every* one of them with +# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0`` +# always targets the langchain-prepended ``request.system_message`` +# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text +# block), giving us exactly one stable cache breakpoint. _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] 
= ( - {"location": "message", "role": "system"}, + {"location": "message", "index": 0}, {"location": "message", "index": -1}, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py index 095413bdb..c56db1528 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/create_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_create_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the create_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def create_confluence_page( title: str, @@ -42,160 +60,163 @@ def create_create_confluence_page_tool( """ logger.info(f"create_confluence_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - return {"status": "error", "message": context["error"]} + if "error" in context: + return {"status": "error", "message": context["error"]} - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Confluence accounts need re-authentication.", - "connector_type": "confluence", - } + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Confluence accounts need re-authentication.", + "connector_type": "confluence", + } - result = request_approval( - action_type="confluence_page_creation", - tool_name="create_confluence_page", - params={ - "title": title, - "content": content, - "space_id": space_id, - "connector_id": connector_id, - }, - context=context, - ) + result = request_approval( + action_type="confluence_page_creation", + tool_name="create_confluence_page", + params={ + "title": title, + "content": content, + "space_id": space_id, + "connector_id": connector_id, + }, + context=context, + ) - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) or "" - final_space_id = result.params.get("space_id", space_id) - final_connector_id = result.params.get("connector_id", connector_id) + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) or "" + final_space_id = result.params.get("space_id", space_id) + final_connector_id = result.params.get("connector_id", connector_id) - if not final_title or not final_title.strip(): - return {"status": "error", "message": "Page title cannot be empty."} - if not final_space_id: - return {"status": "error", "message": "A space must be selected."} + if not final_title or not final_title.strip(): + return {"status": "error", "message": "Page title cannot be empty."} + if not final_space_id: + return {"status": "error", "message": "A space must be selected."} - from sqlalchemy.future import select + from sqlalchemy.future import select - from app.db import SearchSourceConnector, SearchSourceConnectorType + from app.db import SearchSourceConnector, SearchSourceConnectorType - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Confluence connector found.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - api_result = await client.create_page( - space_id=final_space_id, - title=final_title, - body=final_content, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Confluence account needs additional permissions. 
Please re-authenticate in connector settings.", - } - raise - - page_id = str(api_result.get("id", "")) - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=page_id, - page_title=final_title, - space_id=final_space_id, - body_content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Confluence connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - return { - "status": "success", - "page_id": page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", - } + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + api_result = await client.create_page( + space_id=final_space_id, + title=final_title, + body=final_content, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_id = str(api_result.get("id", "")) + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=page_id, + page_title=final_title, + space_id=final_space_id, + body_content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "page_id": page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py index 7c03c2760..d4cd5032f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/delete_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_delete_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the delete_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def delete_confluence_page( page_title_or_id: str, @@ -43,137 +61,143 @@ def create_delete_confluence_page_tool( f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, page_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - page_data = context["page"] - page_id = page_data["page_id"] - page_title = page_data.get("page_title", "") - document_id = page_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_deletion", - tool_name="delete_confluence_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, page_title_or_id ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + page_title = page_data.get("page_title", "") + document_id = page_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_deletion", + tool_name="delete_confluence_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - await client.delete_page(final_page_id) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this page.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } - message = f"Confluence page '{page_title}' deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + await client.delete_page(final_page_id) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "page_id": final_page_id, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Confluence page '{page_title}' deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." 
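Aside (editor's sketch, not part of the patch): every factory touched by this diff now follows the same shape — the registry-supplied db_session is discarded, and each tool call opens its own session via async_session_maker, so the closure returned by the factory stays valid when the compiled-agent cache reuses it across requests. A minimal version of that shape is sketched below; create_example_tool and example_tool are illustrative placeholders, while async_session_maker and the @tool decorator are the real imports used above.

from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import async_session_maker


def create_example_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    del db_session  # reserved for registry compatibility; never captured

    @tool
    async def example_tool(query: str) -> dict[str, Any]:
        """Illustrative tool that opens a short-lived session per call."""
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        async with async_session_maker() as session:
            # All DB work happens on this per-call session, so a cached
            # closure never touches a stale or closed request session.
            _ = session  # stand-in for real queries
            return {"status": "success", "query": query}

    return example_tool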
+ + return { + "status": "success", + "page_id": final_page_id, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py index 791d0d8c5..51c205e00 100644 --- a/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/confluence/update_page.py @@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.confluence_history import ConfluenceHistoryConnector +from app.db import async_session_maker from app.services.confluence import ConfluenceToolMetadataService logger = logging.getLogger(__name__) @@ -18,6 +19,23 @@ def create_update_confluence_page_tool( user_id: str | None = None, connector_id: int | None = None, ): + """ + Factory function to create the update_confluence_page tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_confluence_page tool + """ + del db_session # per-call session — see docstring + @tool async def update_confluence_page( page_title_or_id: str, @@ -45,164 +63,168 @@ def create_update_confluence_page_tool( f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Confluence tool not properly configured.", } try: - metadata_service = ConfluenceToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title_or_id - ) + async with async_session_maker() as db_session: + metadata_service = ConfluenceToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title_or_id + ) - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "confluence", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + page_data = context["page"] + page_id = page_data["page_id"] + current_title = page_data["page_title"] + current_body = page_data.get("body", "") + current_version = page_data.get("version", 1) + document_id = page_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="confluence_page_update", + tool_name="update_confluence_page", + params={ + "page_id": page_id, + "document_id": document_id, + "new_title": new_title, + "new_content": new_content, + "version": current_version, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: return { - 
"status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "confluence", + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - page_data = context["page"] - page_id = page_data["page_id"] - current_title = page_data["page_title"] - current_body = page_data.get("body", "") - current_version = page_data.get("version", 1) - document_id = page_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="confluence_page_update", - tool_name="update_confluence_page", - params={ - "page_id": page_id, - "document_id": document_id, - "new_title": new_title, - "new_content": new_content, - "version": current_version, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_title = result.params.get("new_title", new_title) or current_title - final_content = result.params.get("new_content", new_content) - if final_content is None: - final_content = current_body - final_version = result.params.get("version", current_version) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this page.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + final_page_id = result.params.get("page_id", page_id) + final_title = result.params.get("new_title", new_title) or current_title + final_content = result.params.get("new_content", new_content) + if final_content is None: + final_content = current_body + final_version = result.params.get("version", current_version) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Confluence connector is invalid.", - } + final_document_id = result.params.get("document_id", document_id) - try: - client = ConfluenceHistoryConnector( - session=db_session, connector_id=final_connector_id - ) - api_result = await client.update_page( - page_id=final_page_id, - title=final_title, - body=final_content, - version_number=final_version + 1, - ) - await client.close() - except Exception as api_err: - if ( - "http 403" in str(api_err).lower() - or "status code 403" in str(api_err).lower() - ): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: return { - "status": 
"insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "No connector found for this page.", } - raise - page_links = ( - api_result.get("_links", {}) if isinstance(api_result, dict) else {} - ) - page_url = "" - if page_links.get("base") and page_links.get("webui"): - page_url = f"{page_links['base']}{page_links['webui']}" - - kb_message_suffix = "" - if final_document_id: - try: - from app.services.confluence import ConfluenceKBSyncService - - kb_service = ConfluenceKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - page_id=final_page_id, - user_id=user_id, - search_space_id=search_space_id, + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.CONFLUENCE_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Confluence connector is invalid.", + } + + try: + client = ConfluenceHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + api_result = await client.update_page( + page_id=final_page_id, + title=final_title, + body=final_content, + version_number=final_version + 1, + ) + await client.close() + except Exception as api_err: + if ( + "http 403" in str(api_err).lower() + or "status code 403" in str(api_err).lower() + ): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + page_links = ( + api_result.get("_links", {}) if isinstance(api_result, dict) else {} + ) + page_url = "" + if page_links.get("base") and page_links.get("webui"): + page_url = f"{page_links['base']}{page_links['webui']}" + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.confluence import ConfluenceKBSyncService + + kb_service = ConfluenceKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + page_id=final_page_id, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." 
- ) - return { - "status": "success", - "page_id": final_page_id, - "page_url": page_url, - "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "page_id": final_page_id, + "page_url": page_url, + "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py index 5675a42e6..6420a90e6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker from app.services.mcp_oauth.registry import MCP_SERVICES logger = logging.getLogger(__name__) @@ -53,6 +53,23 @@ def create_get_connected_accounts_tool( search_space_id: int, user_id: str, ) -> StructuredTool: + """Factory function to create the get_connected_accounts tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to scope account discovery to. + user_id: User ID to scope account discovery to. + + Returns: + Configured StructuredTool for connected-accounts discovery. + """ + del db_session # per-call session — see docstring async def _run(service: str) -> list[dict[str, Any]]: svc_cfg = MCP_SERVICES.get(service) @@ -68,40 +85,41 @@ def create_get_connected_accounts_tool( except ValueError: return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == connector_type, + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == connector_type, + ) ) - ) - connectors = result.scalars().all() + connectors = result.scalars().all() - if not connectors: - return [ - { - "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + if not connectors: + return [ + { + "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." 
+ } + ] + + is_multi = len(connectors) > 1 + + accounts: list[dict[str, Any]] = [] + for conn in connectors: + cfg = conn.config or {} + entry: dict[str, Any] = { + "connector_id": conn.id, + "display_name": _extract_display_name(conn), + "service": service, } - ] + if is_multi: + entry["tool_prefix"] = f"{service}_{conn.id}" + for key in svc_cfg.account_metadata_keys: + if key in cfg: + entry[key] = cfg[key] + accounts.append(entry) - is_multi = len(connectors) > 1 - - accounts: list[dict[str, Any]] = [] - for conn in connectors: - cfg = conn.config or {} - entry: dict[str, Any] = { - "connector_id": conn.id, - "display_name": _extract_display_name(conn), - "service": service, - } - if is_multi: - entry["tool_prefix"] = f"{service}_{conn.id}" - for key in svc_cfg.account_metadata_keys: - if key in cfg: - entry[key] = cfg[key] - accounts.append(entry) - - return accounts + return accounts return StructuredTool( name="get_connected_accounts", diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py index 3cc99ac17..01159a261 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_discord_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_discord_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_discord_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_discord_channels() -> dict[str, Any]: """List text channels in the connected Discord server. @@ -22,59 +41,60 @@ def create_list_discord_channels_tool( Returns: Dictionary with status and a list of channels (id, name). 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - guild_id = get_guild_id(connector) - if not guild_id: - return { - "status": "error", - "message": "No guild ID in Discord connector config.", - } - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/guilds/{guild_id}/channels", - headers={"Authorization": f"Bot {token}"}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + guild_id = get_guild_id(connector) + if not guild_id: + return { + "status": "error", + "message": "No guild ID in Discord connector config.", + } - # Type 0 = text channel - channels = [ - {"id": ch["id"], "name": ch["name"]} - for ch in resp.json() - if ch.get("type") == 0 - ] - return { - "status": "success", - "guild_id": guild_id, - "channels": channels, - "total": len(channels), - } + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/guilds/{guild_id}/channels", + headers={"Authorization": f"Bot {token}"}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + # Type 0 = text channel + channels = [ + {"id": ch["id"], "name": ch["name"]} + for ch in resp.json() + if ch.get("type") == 0 + ] + return { + "status": "success", + "guild_id": guild_id, + "channels": channels, + "total": len(channels), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py index d8bf989a1..88d6cdd49 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import DISCORD_API, get_bot_token, get_discord_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_discord_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_discord_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. 
+ + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_discord_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_discord_messages( channel_id: str, @@ -30,7 +49,7 @@ def create_read_discord_messages_tool( Dictionary with status and a list of messages including id, author, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -39,55 +58,56 @@ def create_read_discord_messages_tool( limit = min(limit, 50) try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{DISCORD_API}/channels/{channel_id}/messages", - headers={"Authorization": f"Bot {token}"}, - params={"limit": limit}, - timeout=15.0, + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + token = get_bot_token(connector) - messages = [ - { - "id": m["id"], - "author": m.get("author", {}).get("username", "Unknown"), - "content": m.get("content", ""), - "timestamp": m.get("timestamp", ""), - } - for m in resp.json() - ] + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/channels/{channel_id}/messages", + headers={"Authorization": f"Bot {token}"}, + params={"limit": limit}, + timeout=15.0, + ) - return { - "status": "success", - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + messages = [ + { + "id": m["id"], + "author": m.get("author", {}).get("username", "Unknown"), + "content": m.get("content", ""), + "timestamp": m.get("timestamp", ""), + } + for m in resp.json() + ] + + return { + "status": "success", + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py index 236cd017a..5fe6fde35 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from 
sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import DISCORD_API, get_bot_token, get_discord_connector @@ -17,6 +18,23 @@ def create_send_discord_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_discord_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_discord_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_discord_message( channel_id: str, @@ -34,7 +52,7 @@ def create_send_discord_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Discord tool not properly configured.", @@ -47,64 +65,65 @@ def create_send_discord_message_tool( } try: - connector = await get_discord_connector( - db_session, search_space_id, user_id - ) - if not connector: - return {"status": "error", "message": "No Discord connector found."} + async with async_session_maker() as db_session: + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Discord connector found."} - result = request_approval( - action_type="discord_send_message", - tool_name="send_discord_message", - params={"channel_id": channel_id, "content": content}, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_channel = result.params.get("channel_id", channel_id) - - token = get_bot_token(connector) - - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{DISCORD_API}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bot {token}", - "Content-Type": "application/json", - }, - json={"content": final_content}, - timeout=15.0, + result = request_approval( + action_type="discord_send_message", + tool_name="send_discord_message", + params={"channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Discord bot token is invalid.", - "connector_type": "discord", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Bot lacks permission to send messages in this channel.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Discord API error: {resp.status_code}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": f"Message sent to channel {final_channel}.", - } + final_content = result.params.get("content", content) + final_channel = result.params.get("channel_id", channel_id) + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{DISCORD_API}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + }, + json={"content": final_content}, + timeout=15.0, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Bot lacks permission to send messages in this channel.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to channel {final_channel}.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py index 22d8a8a27..7aae034cc 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.dropbox.client import DropboxClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -59,6 +59,23 @@ def create_create_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_dropbox_file( name: str, @@ -82,184 +99,191 @@ def create_create_dropbox_file_tool( f"create_dropbox_file called: name='{name}', file_type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No Dropbox connector found. 
Please connect Dropbox in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Dropbox accounts need re-authentication.", - "connector_type": "dropbox", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = DropboxClient(session=db_session, connector_id=cid) - items, err = await client.list_folder("") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - { - "folder_path": item.get("path_lower", ""), - "name": item["name"], - } - for item in items - if item.get(".tag") == "folder" and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - "supported_types": _SUPPORTED_TYPES, - } - - result = request_approval( - action_type="dropbox_file_creation", - tool_name="create_dropbox_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_path": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_path = result.params.get("parent_folder_path") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_extension(final_name, final_file_type) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.DROPBOX_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected Dropbox connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No Dropbox connector found. 
Please connect Dropbox in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected Dropbox accounts need re-authentication.", + "connector_type": "dropbox", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = DropboxClient(session=db_session, connector_id=cid) + items, err = await client.list_folder("") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + { + "folder_path": item.get("path_lower", ""), + "name": item["name"], + } + for item in items + if item.get(".tag") == "folder" and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, + "supported_types": _SUPPORTED_TYPES, } - client = DropboxClient(session=db_session, connector_id=connector.id) - - parent_path = final_parent_folder_path or "" - file_path = ( - f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" - ) - - if final_file_type == "paper": - created = await client.create_paper_doc(file_path, final_content or "") - file_id = created.get("file_id", "") - web_url = created.get("url", "") - else: - docx_bytes = _markdown_to_docx(final_content or "") - created = await client.upload_file( - file_path, docx_bytes, mode="add", autorename=True + result = request_approval( + action_type="dropbox_file_creation", + tool_name="create_dropbox_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_path": None, + }, + context=context, ) - file_id = created.get("id", "") - web_url = "" - logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } - kb_message_suffix = "" - try: - from app.services.dropbox import DropboxKBSyncService + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_path = result.params.get("parent_folder_path") - kb_service = DropboxKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=file_id, - file_name=final_name, - file_path=file_path, - web_url=web_url, - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." 
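Aside (editor's sketch, not part of the patch): the auth_expired flag consulted when building the accounts context above is the same one the 403 handlers in the Confluence and Dropbox tools write back into connector.config. Because config is a JSON column, the write side rebuilds the dict, calls flag_modified, and commits on the per-call session; the read side is a plain .get(). mark_auth_expired and is_auth_expired below are illustrative names, not functions from the codebase.

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified


async def mark_auth_expired(db_session: AsyncSession, connector) -> None:
    """Persist config["auth_expired"] = True on a connector row (write side)."""
    connector.config = {**(connector.config or {}), "auth_expired": True}
    flag_modified(connector, "config")  # ensure the JSON column is marked dirty
    await db_session.commit()


def is_auth_expired(connector) -> bool:
    """Read side used when building the accounts context."""
    return bool((connector.config or {}).get("auth_expired", False))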
+ if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_extension(final_name, final_file_type) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + connector = result.scalars().first() else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + connector = connectors[0] - return { - "status": "success", - "file_id": file_id, - "name": final_name, - "web_url": web_url, - "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", - } + if not connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid.", + } + + client = DropboxClient(session=db_session, connector_id=connector.id) + + parent_path = final_parent_folder_path or "" + file_path = ( + f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" + ) + + if final_file_type == "paper": + created = await client.create_paper_doc( + file_path, final_content or "" + ) + file_id = created.get("file_id", "") + web_url = created.get("url", "") + else: + docx_bytes = _markdown_to_docx(final_content or "") + created = await client.upload_file( + file_path, docx_bytes, mode="add", autorename=True + ) + file_id = created.get("id", "") + web_url = "" + + logger.info(f"Dropbox file created: id={file_id}, name={final_name}") + + kb_message_suffix = "" + try: + from app.services.dropbox import DropboxKBSyncService + + kb_service = DropboxKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=file_id, + file_name=final_name, + file_path=file_path, + web_url=web_url, + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
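Aside (editor's sketch, not part of the patch): the create/update tools all derive their user-facing suffix the same way — a successful KB sync appends one message, and any failure or non-success result degrades to the "next scheduled sync" wording without failing the tool call. Condensed into one helper below; kb_sync_suffix and sync_call are illustrative names standing in for e.g. DropboxKBSyncService.sync_after_create.

import logging
from collections.abc import Awaitable, Callable
from typing import Any

logger = logging.getLogger(__name__)


async def kb_sync_suffix(
    sync_call: Callable[[], Awaitable[dict[str, Any]]],
    fallback: str = " This file will be added to your knowledge base in the next scheduled sync.",
) -> str:
    """Return the message suffix; a failed sync never fails the tool call."""
    try:
        result = await sync_call()
        if result.get("status") == "success":
            return " Your knowledge base has also been updated."
    except Exception as err:
        logger.warning(f"KB sync failed: {err}")
    return fallback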
+ + return { + "status": "success", + "file_id": file_id, + "name": final_name, + "web_url": web_url, + "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py index 12559b57a..0e59e49db 100644 --- a/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/dropbox/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_dropbox_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_dropbox_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_dropbox_file( file_name: str, @@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool( f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Dropbox tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.DROPBOX_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.DROPBOX_FILE, - func.lower( - cast( - Document.document_metadata["dropbox_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed Dropbox files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." 
- ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_path = meta.get("dropbox_path") - file_id = meta.get("dropbox_file_id") - document_id = document.id - - if not file_path: - return { - "status": "error", - "message": "File path is missing. Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.DROPBOX_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.DROPBOX_FILE, + func.lower( + cast( + Document.document_metadata["dropbox_file_name"], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Dropbox connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "dropbox", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed Dropbox files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "file_path": file_path, - "name": file_name, - "document_id": document_id, - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="dropbox_file_trash", - tool_name="delete_dropbox_file", - params={ - "file_path": file_path, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_path = meta.get("dropbox_path") + file_id = meta.get("dropbox_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_path: + return { + "status": "error", + "message": "File path is missing. 
Please re-index the file.", + } - final_file_path = result.params.get("file_path", file_path) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected Dropbox connector is invalid or has been disconnected.", + "message": "Dropbox connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "dropbox", + } - client = DropboxClient(session=db_session, connector_id=actual_connector_id) - await client.delete_file(final_file_path) + context = { + "file": { + "file_id": file_id, + "file_path": file_path, + "name": file_name, + "document_id": document_id, + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info(f"Dropbox file deleted: path={final_file_path}") - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": file_id, - "message": f"Successfully deleted '{file_name}' from Dropbox.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File deleted, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="dropbox_file_trash", + tool_name="delete_dropbox_file", + params={ + "file_path": file_path, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_file_path = result.params.get("file_path", file_path) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DROPBOX_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected Dropbox connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" + ) + + client = DropboxClient( + session=db_session, connector_id=actual_connector_id + ) + await client.delete_file(final_file_path) + + logger.info(f"Dropbox file deleted: path={final_file_path}") + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": file_id, + "message": f"Successfully deleted '{file_name}' from Dropbox.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File deleted, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py index 7e9ddf7d3..c88b48d2d 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/create_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_create_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. 
Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def create_gmail_draft( to: str, @@ -57,267 +75,276 @@ def create_create_gmail_draft_tool( """ logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_draft_creation", - tool_name="create_gmail_draft", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Gmail accounts have expired authentication") return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. 
Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + + logger.info( + f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'" + ) + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.", + } + + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", + } + actual_connector_id = connector.id - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" ) - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - created, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_CREATE_EMAIL_DRAFT", - { - "user_id": "me", - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(created, dict): - created = {} + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } else: - from googleapiclient.discovery import build + from google.oauth2.credentials import Credentials - gmail_service = build("gmail", "v1", credentials=creds) - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError + from app.config import config + from app.utils.oauth_security import TokenEncryption - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + 
scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + + created, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_CREATE_EMAIL_DRAFT", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(created, dict): + created = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .create(userId="me", body={"message": {"raw": raw}}) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info(f"Gmail draft created: id={created.get('id')}") + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService + logger.info(f"Gmail draft created: id={created.get('id')}") - kb_service = GmailKBSyncService(db_session) - draft_message = created.get("message", {}) - kb_result = await kb_service.sync_after_create( - message_id=draft_message.get("id", ""), - thread_id=draft_message.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - draft_id=created.get("id"), - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." 
- else: + kb_message_suffix = "" + try: + from app.services.gmail import GmailKBSyncService + + kb_service = GmailKBSyncService(db_session) + draft_message = created.get("message", {}) + kb_result = await kb_service.sync_after_create( + message_id=draft_message.get("id", ""), + thread_id=draft_message.get("threadId", ""), + subject=final_subject, + sender="me", + date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + body_text=final_body, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + draft_id=created.get("id"), + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "draft_id": created.get("id"), - "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } + return { + "status": "success", + "draft_id": created.get("id"), + "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index 1964181e4..464713591 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -5,7 +5,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -20,6 +20,23 @@ def create_read_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def read_gmail_email(message_id: str) -> dict[str, Any]: """Read the full content of a specific Gmail email by its message ID. @@ -32,108 +49,115 @@ def create_read_gmail_email_tool( Returns: Dictionary with status and the full email content formatted as markdown. 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Gmail tool not properly configured."} try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + connector = result.scalars().first() + if not connector: return { "status": "error", - "message": "Composio connected account ID not found.", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found.", + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _format_gmail_summary, + ) + from app.services.composio_service import ComposioService + + service = ComposioService() + detail, error = await service.get_gmail_message_detail( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if error: + return {"status": "error", "message": error} + if not detail: + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } + + summary = _format_gmail_summary(detail) + content = ( + f"# {summary['subject']}\n\n" + f"**From:** {summary['from']}\n" + f"**To:** {summary['to']}\n" + f"**Date:** {summary['date']}\n\n" + f"## Message Content\n\n" + f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {summary['message_id']}\n" + f"- **Thread ID:** {summary['thread_id']}\n" + ) + return { + "status": "success", + "message_id": summary["message_id"] or message_id, + "content": content, } from app.agents.new_chat.tools.gmail.search_emails import ( - _format_gmail_summary, + _build_credentials, ) - from app.services.composio_service import ComposioService - service = ComposioService() - detail, error = await service.get_gmail_message_detail( - connected_account_id=cca_id, - entity_id=f"surfsense_{user_id}", - message_id=message_id, + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, ) + + detail, error = await gmail.get_message_details(message_id) if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } return {"status": "error", "message": error} 
+ if not detail: return { "status": "not_found", "message": f"Email with ID '{message_id}' not found.", } - summary = _format_gmail_summary(detail) - content = ( - f"# {summary['subject']}\n\n" - f"**From:** {summary['from']}\n" - f"**To:** {summary['to']}\n" - f"**Date:** {summary['date']}\n\n" - f"## Message Content\n\n" - f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" - f"## Message Details\n\n" - f"- **Message ID:** {summary['message_id']}\n" - f"- **Thread ID:** {summary['thread_id']}\n" - ) + content = gmail.format_message_to_markdown(detail) + return { "status": "success", - "message_id": summary["message_id"] or message_id, + "message_id": message_id, "content": content, } - from app.agents.new_chat.tools.gmail.search_emails import _build_credentials - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - detail, error = await gmail.get_message_details(message_id) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): - return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", - } - return {"status": "error", "message": error} - - if not detail: - return { - "status": "not_found", - "message": f"Email with ID '{message_id}' not found.", - } - - content = gmail.format_message_to_markdown(detail) - - return {"status": "success", "message_id": message_id, "content": content} - except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index 59886159a..3ce154c53 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -6,7 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -124,6 +124,23 @@ def create_search_gmail_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the search_gmail tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured search_gmail tool + """ + del db_session # per-call session — see docstring + @tool async def search_gmail( query: str, @@ -142,91 +159,92 @@ def create_search_gmail_tool( Dictionary with status and a list of email summaries including message_id, subject, from, date, snippet. 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Gmail tool not properly configured."} max_results = min(max_results, 20) try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ): - return await _search_composio_gmail( - connector, str(user_id), query, max_results - ) - - creds = _build_credentials(connector) - - from app.connectors.google_gmail_connector import GoogleGmailConnector - - gmail = GoogleGmailConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - messages_list, error = await gmail.get_messages_list( - max_results=max_results, query=query - ) - if error: - if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() - ): + connector = result.scalars().first() + if not connector: return { - "status": "auth_error", - "message": error, - "connector_type": "gmail", + "status": "error", + "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", } - return {"status": "error", "message": error} - if not messages_list: - return { - "status": "success", - "emails": [], - "total": 0, - "message": "No emails found.", - } + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + return await _search_composio_gmail( + connector, str(user_id), query, max_results + ) - emails = [] - for msg in messages_list: - detail, err = await gmail.get_message_details(msg["id"]) - if err: - continue - headers = { - h["name"].lower(): h["value"] - for h in detail.get("payload", {}).get("headers", []) - } - emails.append( - { - "message_id": detail.get("id"), - "thread_id": detail.get("threadId"), - "subject": headers.get("subject", "No Subject"), - "from": headers.get("from", "Unknown"), - "to": headers.get("to", ""), - "date": headers.get("date", ""), - "snippet": detail.get("snippet", ""), - "labels": detail.get("labelIds", []), - } + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, ) - return {"status": "success", "emails": emails, "total": len(emails)} + messages_list, error = await gmail.get_messages_list( + max_results=max_results, query=query + ) + if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } + return {"status": "error", "message": error} + + if not messages_list: + return { + "status": "success", + "emails": [], + "total": 0, + "message": "No emails found.", + } + + emails = [] + for msg in messages_list: + detail, err = await gmail.get_message_details(msg["id"]) + if err: + continue + headers = { + h["name"].lower(): h["value"] + for h in detail.get("payload", {}).get("headers", []) + } + emails.append( + { + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + } + ) + + return {"status": "success", "emails": emails, "total": len(emails)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py index 79ff2d9c7..4d5aa3bcc 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_send_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. 
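For illustration, a minimal sketch (not taken from this diff) of the factory shape these tool modules now share: the injected db_session is discarded and every invocation opens its own short-lived session via async_session_maker, so the closure stays valid when the compiled agent is cached and reused across requests. Only async_session_maker, the @tool decorator, and AsyncSession correspond to names in the real code; the engine setup, create_example_tool, and example_tool below are illustrative stand-ins.

    from typing import Any

    from langchain_core.tools import tool
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import (
        AsyncSession,
        async_sessionmaker,
        create_async_engine,
    )

    # Stand-in for app.db.async_session_maker; the real one is configured
    # for the application's Postgres database.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


    def create_example_tool(
        db_session: AsyncSession | None = None,
        search_space_id: int | None = None,
        user_id: str | None = None,
    ):
        """Factory sketch: db_session is kept only for registry compatibility."""
        del db_session  # per-call sessions are opened inside the tool body

        @tool
        async def example_tool(query: str) -> dict[str, Any]:
            """Illustrative tool that opens its own short-lived session per call."""
            if search_space_id is None or user_id is None:
                return {"status": "error", "message": "Tool not properly configured."}
            # A fresh session per invocation keeps the closure safe to share
            # across HTTP requests from a cached, compiled agent.
            async with async_session_maker() as session:
                row = (await session.execute(text("SELECT 1"))).scalar_one()
            return {"status": "success", "result": row, "query": query}

        return example_tool
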
+ + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured send_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def send_gmail_email( to: str, @@ -58,268 +76,277 @@ def create_send_gmail_email_tool( """ logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - logger.info( - f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" - ) - result = request_approval( - action_type="gmail_email_send", - tool_name="send_gmail_email", - params={ - "to": to, - "subject": subject, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", to) - final_subject = result.params.get("subject", subject) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get("connector_id") - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Gmail accounts have expired authentication") return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. 
Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), + + logger.info( + f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" + ) + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={ + "to": to, + "subject": subject, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": None, + }, + context=context, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", + } + + final_to = result.params.get("to", to) + final_subject = result.params.get("subject", subject) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get("connector_id") + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. 
Please connect Gmail in your workspace settings.", + } + actual_connector_id = connector.id - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" ) - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - sent, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_SEND_EMAIL", - { - "user_id": "me", - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(sent, dict): - sent = {} + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } else: - from googleapiclient.discovery import build + from google.oauth2.credentials import Credentials - gmail_service = build("gmail", "v1", credentials=creds) - sent = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .send(userId="me", body={"raw": raw}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError + from app.config import config + from app.utils.oauth_security import TokenEncryption - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), 
+ expiry=datetime.fromisoformat(exp) if exp else None, + ) + + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + + sent, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_SEND_EMAIL", + { + "user_id": "me", + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(sent, dict): + sent = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + sent = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .send(userId="me", body={"raw": raw}) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info( - f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" - ) + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - kb_message_suffix = "" - try: - from app.services.gmail import GmailKBSyncService - - kb_service = GmailKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - message_id=sent.get("id", ""), - thread_id=sent.get("threadId", ""), - subject=final_subject, - sender="me", - date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - body_text=final_body, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." 
- except Exception as kb_err: - logger.warning(f"KB sync after send failed: {kb_err}") - kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "message_id": sent.get("id"), - "thread_id": sent.get("threadId"), - "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", - } + kb_message_suffix = "" + try: + from app.services.gmail import GmailKBSyncService + + kb_service = GmailKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + message_id=sent.get("id", ""), + thread_id=sent.get("threadId", ""), + subject=final_subject, + sender="me", + date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + body_text=final_body, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after send failed: {kb_err}") + kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "message_id": sent.get("id"), + "thread_id": sent.get("threadId"), + "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py index 4e710dc72..95f5b4e6c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py @@ -7,6 +7,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -17,6 +18,23 @@ def create_trash_gmail_email_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the trash_gmail_email tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured trash_gmail_email tool + """ + del db_session # per-call session — see docstring + @tool async def trash_gmail_email( email_subject_or_id: str, @@ -55,254 +73,261 @@ def create_trash_gmail_email_tool( f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. 
Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, email_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Email not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, email_subject_or_id ) - return { - "status": "auth_error", - "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Email not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not message_id: - return { - "status": "error", - "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Gmail account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "gmail", + } - logger.info( - f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="gmail_email_trash", - tool_name="trash_gmail_email", - params={ - "message_id": message_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not trashed. 
Do not ask again or suggest alternatives.", - } - - final_message_id = result.params.get("message_id", message_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this email.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + if not message_id: return { "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", + "message": "Message ID is missing from the indexed document. Please re-index the email and try again.", } - else: - from google.oauth2.credentials import Credentials - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="gmail_email_trash", + tool_name="trash_gmail_email", + params={ + "message_id": message_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - try: - if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - ) - - _trashed, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_MOVE_TO_TRASH", - {"user_id": "me", "message_id": final_message_id}, - ) - if error: - 
raise RuntimeError(error) - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .trash(userId="me", id=final_message_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.", } - raise - logger.info(f"Gmail email trashed: message_id={final_message_id}") - - trash_result: dict[str, Any] = { - "status": "success", - "message_id": final_message_id, - "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"Email trashed, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + final_message_id = result.params.get("message_id", message_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb ) - return trash_result + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this email.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + + logger.info( + f"Trashing 
Gmail email: message_id='{final_message_id}', connector={final_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + ) + + _trashed, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_MOVE_TO_TRASH", + {"user_id": "me", "message_id": final_message_id}, + ) + if error: + raise RuntimeError(error) + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .trash(userId="me", id=final_message_id) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. 
Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Gmail email trashed: message_id={final_message_id}") + + trash_result: dict[str, Any] = { + "status": "success", + "message_id": final_message_id, + "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"Email trashed, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py index 50956f03a..129b7defb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/update_draft.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_update_gmail_draft_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the update_gmail_draft tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_gmail_draft tool + """ + del db_session # per-call session — see docstring + @tool async def update_gmail_draft( draft_subject_or_id: str, @@ -76,324 +94,329 @@ def create_update_gmail_draft_tool( f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Gmail tool not properly configured. 
Please contact support.", } try: - metadata_service = GmailToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, draft_subject_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Draft not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Gmail account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } - - email = context["email"] - message_id = email["message_id"] - document_id = email.get("document_id") - connector_id_from_context = account["id"] - draft_id_from_context = context.get("draft_id") - - original_subject = email.get("subject", draft_subject_or_id) - final_subject_default = subject if subject else original_subject - final_to_default = to if to else "" - - logger.info( - f"Requesting approval for updating Gmail draft: '{original_subject}' " - f"(message_id={message_id}, draft_id={draft_id_from_context})" - ) - result = request_approval( - action_type="gmail_draft_update", - tool_name="update_gmail_draft", - params={ - "message_id": message_id, - "draft_id": draft_id_from_context, - "to": final_to_default, - "subject": final_subject_default, - "body": body, - "cc": cc, - "bcc": bcc, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The draft was not updated. 
Do not ask again or suggest alternatives.", - } - - final_to = result.params.get("to", final_to_default) - final_subject = result.params.get("subject", final_subject_default) - final_body = result.params.get("body", body) - final_cc = result.params.get("cc", cc) - final_bcc = result.params.get("bcc", bcc) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_draft_id = result.params.get("draft_id", draft_id_from_context) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this draft.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _gmail_types = [ - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_gmail_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } - - logger.info( - f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" - ) - - is_composio_gmail = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR - ) - if is_composio_gmail: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } - else: - from google.oauth2.credentials import Credentials - - from app.config import config - from app.utils.oauth_security import TokenEncryption - - config_data = dict(connector.config) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - token_encryption = TokenEncryption(config.SECRET_KEY) - if config_data.get("token"): - config_data["token"] = token_encryption.decrypt_token( - config_data["token"] - ) - if config_data.get("refresh_token"): - config_data["refresh_token"] = token_encryption.decrypt_token( - config_data["refresh_token"] - ) - if config_data.get("client_secret"): - config_data["client_secret"] = token_encryption.decrypt_token( - config_data["client_secret"] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + async with async_session_maker() as db_session: + metadata_service = GmailToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, draft_subject_or_id ) - # Resolve draft_id if not already available - if not final_draft_id: - logger.info( - f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" - ) - if is_composio_gmail: - final_draft_id = await _find_composio_draft_id_by_message( - connector, user_id, message_id - ) - else: - from googleapiclient.discovery import build + if "error" in context: + error_msg = context["error"] + if "not found" in 
error_msg.lower(): + logger.warning(f"Draft not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - gmail_service = build("gmail", "v1", credentials=creds) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) - - if not final_draft_id: - return { - "status": "error", - "message": ( - "Could not find this draft in Gmail. " - "It may have already been sent or deleted." - ), - } - - message = MIMEText(final_body) - if final_to: - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - if is_composio_gmail: - from app.agents.new_chat.tools.gmail.composio_helpers import ( - execute_composio_gmail_tool, - split_recipients, - ) - - updated, error = await execute_composio_gmail_tool( - connector, - user_id, - "GMAIL_UPDATE_DRAFT", - { - "user_id": "me", - "draft_id": final_draft_id, - "recipient_email": final_to, - "subject": final_subject, - "body": final_body, - "cc": split_recipients(final_cc), - "bcc": split_recipients(final_bcc), - "is_html": False, - }, - ) - if error: - raise RuntimeError(error) - if not isinstance(updated, dict): - updated = {} - else: - from googleapiclient.discovery import build - - gmail_service = build("gmail", "v1", credentials=creds) - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" + "Gmail account %s has expired authentication", + account.get("id"), ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "The Gmail account for this draft needs re-authentication. 
Please re-authenticate in your connector settings.", + "connector_type": "gmail", } - if isinstance(api_err, HttpError) and api_err.resp.status == 404: + + email = context["email"] + message_id = email["message_id"] + document_id = email.get("document_id") + connector_id_from_context = account["id"] + draft_id_from_context = context.get("draft_id") + + original_subject = email.get("subject", draft_subject_or_id) + final_subject_default = subject if subject else original_subject + final_to_default = to if to else "" + + logger.info( + f"Requesting approval for updating Gmail draft: '{original_subject}' " + f"(message_id={message_id}, draft_id={draft_id_from_context})" + ) + result = request_approval( + action_type="gmail_draft_update", + tool_name="update_gmail_draft", + params={ + "message_id": message_id, + "draft_id": draft_id_from_context, + "to": final_to_default, + "subject": final_subject_default, + "body": body, + "cc": cc, + "bcc": bcc, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.", + } + + final_to = result.params.get("to", final_to_default) + final_subject = result.params.get("subject", final_subject_default) + final_body = result.params.get("body", body) + final_cc = result.params.get("cc", cc) + final_bcc = result.params.get("bcc", bcc) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_draft_id = result.params.get("draft_id", draft_id_from_context) + + if not final_connector_id: return { "status": "error", - "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + "message": "No connector found for this draft.", } - raise - logger.info(f"Gmail draft updated: id={updated.get('id')}") + from sqlalchemy.future import select - kb_message_suffix = "" - if document_id: - try: - from sqlalchemy.future import select as sa_select - from sqlalchemy.orm.attributes import flag_modified + from app.db import SearchSourceConnector, SearchSourceConnectorType - from app.db import Document + _gmail_types = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + ] - doc_result = await db_session.execute( - sa_select(Document).filter(Document.id == document_id) + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_gmail_types), ) - document = doc_result.scalars().first() - if document: - document.source_markdown = final_body - document.title = final_subject - meta = dict(document.document_metadata or {}) - meta["subject"] = final_subject - meta["draft_id"] = updated.get("id", final_draft_id) - updated_msg = updated.get("message", {}) - if updated_msg.get("id"): - meta["message_id"] = updated_msg["id"] - document.document_metadata = meta - flag_modified(document, "document_metadata") - await db_session.commit() - kb_message_suffix = ( - " Your knowledge base has also been updated." 
- ) - logger.info( - f"KB document {document_id} updated for draft {final_draft_id}" + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Gmail connector is invalid or has been disconnected.", + } + + logger.info( + f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" + ) + + is_composio_gmail = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ) + if is_composio_gmail: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + else: + from google.oauth2.credentials import Credentials + + from app.config import config + from app.utils.oauth_security import TokenEncryption + + config_data = dict(connector.config) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = ( + token_encryption.decrypt_token( + config_data["refresh_token"] + ) + ) + if config_data.get("client_secret"): + config_data["client_secret"] = ( + token_encryption.decrypt_token( + config_data["client_secret"] + ) + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + # Resolve draft_id if not already available + if not final_draft_id: + logger.info( + f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" + ) + if is_composio_gmail: + final_draft_id = await _find_composio_draft_id_by_message( + connector, user_id, message_id ) else: - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB update after draft edit failed: {kb_err}") - await db_session.rollback() - kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." + from googleapiclient.discovery import build - return { - "status": "success", - "draft_id": updated.get("id"), - "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", - } + gmail_service = build("gmail", "v1", credentials=creds) + final_draft_id = await _find_draft_id_by_message( + gmail_service, message_id + ) + + if not final_draft_id: + return { + "status": "error", + "message": ( + "Could not find this draft in Gmail. " + "It may have already been sent or deleted." 
+ ), + } + + message = MIMEText(final_body) + if final_to: + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + if is_composio_gmail: + from app.agents.new_chat.tools.gmail.composio_helpers import ( + execute_composio_gmail_tool, + split_recipients, + ) + + updated, error = await execute_composio_gmail_tool( + connector, + user_id, + "GMAIL_UPDATE_DRAFT", + { + "user_id": "me", + "draft_id": final_draft_id, + "recipient_email": final_to, + "subject": final_subject, + "body": final_body, + "cc": split_recipients(final_cc), + "bcc": split_recipients(final_bcc), + "is_html": False, + }, + ) + if error: + raise RuntimeError(error) + if not isinstance(updated, dict): + updated = {} + else: + from googleapiclient.discovery import build + + gmail_service = build("gmail", "v1", credentials=creds) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .update( + userId="me", + id=final_draft_id, + body={"message": {"raw": raw}}, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + if isinstance(api_err, HttpError) and api_err.resp.status == 404: + return { + "status": "error", + "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + } + raise + + logger.info(f"Gmail draft updated: id={updated.get('id')}") + + kb_message_suffix = "" + if document_id: + try: + from sqlalchemy.future import select as sa_select + from sqlalchemy.orm.attributes import flag_modified + + from app.db import Document + + doc_result = await db_session.execute( + sa_select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + document.source_markdown = final_body + document.title = final_subject + meta = dict(document.document_metadata or {}) + meta["subject"] = final_subject + meta["draft_id"] = updated.get("id", final_draft_id) + updated_msg = updated.get("message", {}) + if updated_msg.get("id"): + meta["message_id"] = updated_msg["id"] + document.document_metadata = meta + flag_modified(document, "document_metadata") + await db_session.commit() + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + logger.info( + f"KB document {document_id} updated for draft {final_draft_id}" + ) + else: + kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." 
+ except Exception as kb_err: + logger.warning(f"KB update after draft edit failed: {kb_err}") + await db_session.rollback() + kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." + + return { + "status": "success", + "draft_id": updated.get("id"), + "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py index 0a4720f6f..dec92cc8b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/create_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_create_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_calendar_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def create_calendar_event( summary: str, @@ -60,284 +78,294 @@ def create_create_calendar_event_tool( f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning( - "All Google Calendar accounts have expired authentication" + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - return { - "status": "auth_error", - "message": "All connected Google Calendar accounts need re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - logger.info( - f"Requesting approval for creating calendar event: summary='{summary}'" - ) - result = request_approval( - action_type="google_calendar_event_creation", - tool_name="create_calendar_event", - params={ - "summary": summary, - "start_datetime": start_datetime, - "end_datetime": end_datetime, - "description": description, - "location": location, - "attendees": attendees, - "timezone": context.get("timezone"), - "connector_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not created. Do not ask again or suggest alternatives.", - } - - final_summary = result.params.get("summary", summary) - final_start_datetime = result.params.get("start_datetime", start_datetime) - final_end_datetime = result.params.get("end_datetime", end_datetime) - final_description = result.params.get("description", description) - final_location = result.params.get("location", location) - final_attendees = result.params.get("attendees", attendees) - final_connector_id = result.params.get("connector_id") - - if not final_summary or not final_summary.strip(): - return {"status": "error", "message": "Event summary cannot be empty."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning( + "All Google Calendar accounts have expired authentication" ) + return { + "status": "auth_error", + "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } + + logger.info( + f"Requesting approval for creating calendar event: summary='{summary}'" ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. 
Please connect Google Calendar in your workspace settings.", - } - actual_connector_id = connector.id - - logger.info( - f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" - ) - - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this connector.", - } - else: - config_data = dict(connector.config) - - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + result = request_approval( + action_type="google_calendar_event_creation", + tool_name="create_calendar_event", + params={ + "summary": summary, + "start_datetime": start_datetime, + "end_datetime": end_datetime, + "description": description, + "location": location, + "attendees": attendees, + "timezone": context.get("timezone"), + "connector_id": None, + }, + context=context, ) - tz = context.get("timezone", "UTC") - event_body: dict[str, Any] = { - "summary": final_summary, - "start": {"dateTime": final_start_datetime, "timeZone": tz}, - "end": {"dateTime": final_end_datetime, "timeZone": tz}, - } - if final_description: - event_body["description"] = final_description - if final_location: - event_body["location"] = final_location - if final_attendees: - event_body["attendees"] = [ - {"email": e.strip()} for e in final_attendees if e.strip() + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The event was not created. 
Do not ask again or suggest alternatives.", + } + + final_summary = result.params.get("summary", summary) + final_start_datetime = result.params.get( + "start_datetime", start_datetime + ) + final_end_datetime = result.params.get("end_datetime", end_datetime) + final_description = result.params.get("description", description) + final_location = result.params.get("location", location) + final_attendees = result.params.get("attendees", attendees) + final_connector_id = result.params.get("connector_id") + + if not final_summary or not final_summary.strip(): + return { + "status": "error", + "message": "Event summary cannot be empty.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, ] - try: - if is_composio_calendar: - from app.services.composio_service import ComposioService - - composio_params = { - "calendar_id": "primary", - "summary": final_summary, - "start_datetime": final_start_datetime, - "end_datetime": final_end_datetime, - "timezone": tz, - "attendees": final_attendees or [], - } - if final_description: - composio_params["description"] = final_description - if final_location: - composio_params["location"] = final_location - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_CREATE_EVENT", - params=composio_params, - entity_id=f"surfsense_{user_id}", - ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), ) - created = composio_result.get("data", {}) - if isinstance(created, dict): - created = created.get("data", created) - if isinstance(created, dict): - created = created.get("response_data", created) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id else: - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .insert(calendarId="primary", body=event_body) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, 
"auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", + } + actual_connector_id = connector.id - logger.info( - f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_calendar import GoogleCalendarKBSyncService - - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - event_id=created.get("id"), - event_summary=final_summary, - calendar_id="primary", - start_time=final_start_datetime, - end_time=final_end_datetime, - location=final_location, - html_link=created.get("htmlLink"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." 
- return { - "status": "success", - "event_id": created.get("id"), - "html_link": created.get("htmlLink"), - "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", - } + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + tz = context.get("timezone", "UTC") + event_body: dict[str, Any] = { + "summary": final_summary, + "start": {"dateTime": final_start_datetime, "timeZone": tz}, + "end": {"dateTime": final_end_datetime, "timeZone": tz}, + } + if final_description: + event_body["description"] = final_description + if final_location: + event_body["location"] = final_location + if final_attendees: + event_body["attendees"] = [ + {"email": e.strip()} for e in final_attendees if e.strip() + ] + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params = { + "calendar_id": "primary", + "summary": final_summary, + "start_datetime": final_start_datetime, + "end_datetime": final_end_datetime, + "timezone": tz, + "attendees": final_attendees or [], + } + if final_description: + composio_params["description"] = final_description + if final_location: + composio_params["location"] = final_location + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_CREATE_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + created = composio_result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .insert(calendarId="primary", body=event_body) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = 
await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_calendar import GoogleCalendarKBSyncService + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + event_id=created.get("id"), + event_summary=final_summary, + calendar_id="primary", + start_time=final_start_datetime, + end_time=final_end_datetime, + location=final_location, + html_link=created.get("htmlLink"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "event_id": created.get("id"), + "html_link": created.get("htmlLink"), + "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py index 53596ac0f..e7e891b08 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/delete_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,23 @@ def create_delete_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_calendar_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
+ + Returns: + Configured delete_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def delete_calendar_event( event_title_or_id: str, @@ -54,252 +72,258 @@ def create_delete_calendar_event_tool( f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch deletion context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Calendar account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, event_title_or_id ) - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Event not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch deletion context: {error_msg}") + return {"status": "error", "message": error_msg} - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Google Calendar account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } - logger.info( - f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_calendar_event_deletion", - tool_name="delete_calendar_event", - params={ - "event_id": event_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + event = context["event"] + event_id = event["event_id"] + document_id = event.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not deleted. 
Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - - actual_connector_id = connector.id - - logger.info( - f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) - - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + if not event_id: return { "status": "error", - "message": "Composio connected account ID not found for this connector.", + "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", } - else: - config_data = dict(connector.config) - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="google_calendar_event_deletion", + tool_name="delete_calendar_event", + params={ + "event_id": event_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - try: - if is_composio_calendar: - from app.services.composio_service import ComposioService - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_DELETE_EVENT", - params={"calendar_id": "primary", "event_id": final_event_id}, - entity_id=f"surfsense_{user_id}", - ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) - ) - else: - service = 
await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .delete(calendarId="primary", eventId=final_event_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.", } - raise - logger.info(f"Calendar event deleted: event_id={final_event_id}") - - delete_result: dict[str, Any] = { - "status": "success", - "event_id": final_event_id, - "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - delete_result["warning"] = ( - f"Event deleted, but failed to remove from knowledge base: {e!s}" - ) - - delete_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - delete_result["message"] = ( - f"{delete_result.get('message', '')} (also removed from knowledge base)" + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb ) - return delete_result + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this event.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), + ) + ) + connector = 
result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", + } + + actual_connector_id = connector.id + + logger.info( + f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" + ) + + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_DELETE_EVENT", + params={ + "calendar_id": "primary", + "event_id": final_event_id, + }, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .delete(calendarId="primary", eventId=final_event_id) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. 
Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Calendar event deleted: event_id={final_event_id}") + + delete_result: dict[str, Any] = { + "status": "success", + "event_id": final_event_id, + "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + delete_result["warning"] = ( + f"Event deleted, but failed to remove from knowledge base: {e!s}" + ) + + delete_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + delete_result["message"] = ( + f"{delete_result.get('message', '')} (also removed from knowledge base)" + ) + + return delete_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index b5194d15f..e5f18f675 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.agents.new_chat.tools.gmail.search_emails import _build_credentials -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -50,6 +50,23 @@ def create_search_calendar_events_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the search_calendar_events tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured search_calendar_events tool + """ + del db_session # per-call session — see docstring + @tool async def search_calendar_events( start_date: str, @@ -67,7 +84,7 @@ def create_search_calendar_events_tool( Dictionary with status and a list of events including event_id, summary, start, end, location, attendees. 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Calendar tool not properly configured.", @@ -76,84 +93,85 @@ def create_search_calendar_events_tool( max_results = min(max_results, 50) try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + ) ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", - } - - if ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + connector = result.scalars().first() + if not connector: return { "status": "error", - "message": "Composio connected account ID not found for this connector.", + "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", } - from app.services.composio_service import ComposioService - - events_raw, error = await ComposioService().get_calendar_events( - connected_account_id=cca_id, - entity_id=f"surfsense_{user_id}", - time_min=_to_calendar_boundary(start_date, is_end=False), - time_max=_to_calendar_boundary(end_date, is_end=True), - max_results=max_results, - ) - if not events_raw and not error: - error = "No events found in the specified date range." 
- else: - creds = _build_credentials(connector) - - from app.connectors.google_calendar_connector import ( - GoogleCalendarConnector, - ) - - cal = GoogleCalendarConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) - - events_raw, error = await cal.get_all_primary_calendar_events( - start_date=start_date, - end_date=end_date, - max_results=max_results, - ) - - if error: if ( - "re-authenticate" in error.lower() - or "authentication failed" in error.lower() + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - return { - "status": "auth_error", - "message": error, - "connector_type": "google_calendar", - } - if "no events found" in error.lower(): - return { - "status": "success", - "events": [], - "total": 0, - "message": error, - } - return {"status": "error", "message": error} + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } - events = _format_calendar_events(events_raw) + from app.services.composio_service import ComposioService - return {"status": "success", "events": events, "total": len(events)} + events_raw, error = await ComposioService().get_calendar_events( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + time_min=_to_calendar_boundary(start_date, is_end=False), + time_max=_to_calendar_boundary(end_date, is_end=True), + max_results=max_results, + ) + if not events_raw and not error: + error = "No events found in the specified date range." + else: + creds = _build_credentials(connector) + + from app.connectors.google_calendar_connector import ( + GoogleCalendarConnector, + ) + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) + + if error: + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "google_calendar", + } + if "no events found" in error.lower(): + return { + "status": "success", + "events": [], + "total": 0, + "message": error, + } + return {"status": "error", "message": error} + + events = _format_calendar_events(events_raw) + + return {"status": "success", "events": events, "total": len(events)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 1dba36c20..b8561fee6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -9,6 +9,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) @@ -33,6 +34,23 @@ def create_update_calendar_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the update_calendar_event tool. 
+ + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured update_calendar_event tool + """ + del db_session # per-call session — see docstring + @tool async def update_calendar_event( event_title_or_id: str, @@ -74,312 +92,317 @@ def create_update_calendar_event_tool( """ logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Calendar tool not properly configured. Please contact support.", } try: - metadata_service = GoogleCalendarToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, event_title_or_id - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"Event not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - if context.get("auth_expired"): - logger.warning("Google Calendar account has expired authentication") - return { - "status": "auth_error", - "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "google_calendar", - } - - event = context["event"] - event_id = event["event_id"] - document_id = event.get("document_id") - connector_id_from_context = context["account"]["id"] - - if not event_id: - return { - "status": "error", - "message": "Event ID is missing from the indexed document. Please re-index the event and try again.", - } - - logger.info( - f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" - ) - result = request_approval( - action_type="google_calendar_event_update", - tool_name="update_calendar_event", - params={ - "event_id": event_id, - "document_id": document_id, - "connector_id": connector_id_from_context, - "new_summary": new_summary, - "new_start_datetime": new_start_datetime, - "new_end_datetime": new_end_datetime, - "new_description": new_description, - "new_location": new_location, - "new_attendees": new_attendees, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The event was not updated. 
Do not ask again or suggest alternatives.", - } - - final_event_id = result.params.get("event_id", event_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_new_summary = result.params.get("new_summary", new_summary) - final_new_start_datetime = result.params.get( - "new_start_datetime", new_start_datetime - ) - final_new_end_datetime = result.params.get( - "new_end_datetime", new_end_datetime - ) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_location = result.params.get("new_location", new_location) - final_new_attendees = result.params.get("new_attendees", new_attendees) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this event.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _calendar_types = [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_calendar_types), + async with async_session_maker() as db_session: + metadata_service = GoogleCalendarToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, event_title_or_id ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Calendar connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"Event not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - logger.info( - f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" - ) + if context.get("auth_expired"): + logger.warning("Google Calendar account has expired authentication") + return { + "status": "auth_error", + "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_calendar", + } - is_composio_calendar = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ) - if is_composio_calendar: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + event = context["event"] + event_id = event["event_id"] + document_id = event.get("document_id") + connector_id_from_context = context["account"]["id"] + + if not event_id: return { "status": "error", - "message": "Composio connected account ID not found for this connector.", + "message": "Event ID is missing from the indexed document. 
Please re-index the event and try again.", } - else: - config_data = dict(connector.config) - from app.config import config as app_config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and app_config.SECRET_KEY: - token_encryption = TokenEncryption(app_config.SECRET_KEY) - for key in ("token", "refresh_token", "client_secret"): - if config_data.get(key): - config_data[key] = token_encryption.decrypt_token( - config_data[key] - ) - - exp = config_data.get("expiry", "") - if exp: - exp = exp.replace("Z", "") - - creds = Credentials( - token=config_data.get("token"), - refresh_token=config_data.get("refresh_token"), - token_uri=config_data.get("token_uri"), - client_id=config_data.get("client_id"), - client_secret=config_data.get("client_secret"), - scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, + logger.info( + f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})" + ) + result = request_approval( + action_type="google_calendar_event_update", + tool_name="update_calendar_event", + params={ + "event_id": event_id, + "document_id": document_id, + "connector_id": connector_id_from_context, + "new_summary": new_summary, + "new_start_datetime": new_start_datetime, + "new_end_datetime": new_end_datetime, + "new_description": new_description, + "new_location": new_location, + "new_attendees": new_attendees, + }, + context=context, ) - update_body: dict[str, Any] = {} - if final_new_summary is not None: - update_body["summary"] = final_new_summary - if final_new_start_datetime is not None: - update_body["start"] = _build_time_body( - final_new_start_datetime, context + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.", + } + + final_event_id = result.params.get("event_id", event_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context ) - if final_new_end_datetime is not None: - update_body["end"] = _build_time_body(final_new_end_datetime, context) - if final_new_description is not None: - update_body["description"] = final_new_description - if final_new_location is not None: - update_body["location"] = final_new_location - if final_new_attendees is not None: - update_body["attendees"] = [ - {"email": e.strip()} for e in final_new_attendees if e.strip() + final_new_summary = result.params.get("new_summary", new_summary) + final_new_start_datetime = result.params.get( + "new_start_datetime", new_start_datetime + ) + final_new_end_datetime = result.params.get( + "new_end_datetime", new_end_datetime + ) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_location = result.params.get("new_location", new_location) + final_new_attendees = result.params.get("new_attendees", new_attendees) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this event.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _calendar_types = [ + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, ] - if not update_body: - return { - "status": "error", - "message": "No changes specified. 
Please provide at least one field to update.", - } - - try: - if is_composio_calendar: - from app.services.composio_service import ComposioService - - composio_params: dict[str, Any] = { - "calendar_id": "primary", - "event_id": final_event_id, - } - if final_new_summary is not None: - composio_params["summary"] = final_new_summary - if final_new_start_datetime is not None: - composio_params["start_time"] = final_new_start_datetime - if final_new_end_datetime is not None: - composio_params["end_time"] = final_new_end_datetime - if final_new_description is not None: - composio_params["description"] = final_new_description - if final_new_location is not None: - composio_params["location"] = final_new_location - if final_new_attendees is not None: - composio_params["attendees"] = [ - e.strip() for e in final_new_attendees if e.strip() - ] - if not _is_date_only( - final_new_start_datetime or final_new_end_datetime or "" - ): - composio_params["timezone"] = context.get("timezone", "UTC") - - composio_result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLECALENDAR_PATCH_EVENT", - params=composio_params, - entity_id=f"surfsense_{user_id}", + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_calendar_types), ) - if not composio_result.get("success"): - raise RuntimeError( - composio_result.get( - "error", "Unknown Composio Calendar error" - ) - ) - updated = composio_result.get("data", {}) - if isinstance(updated, dict): - updated = updated.get("data", updated) - if isinstance(updated, dict): - updated = updated.get("response_data", updated) - else: - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) + ) + connector = result.scalars().first() + if not connector: return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. 
Please re-authenticate in connector settings.", + "status": "error", + "message": "Selected Google Calendar connector is invalid or has been disconnected.", } - raise - logger.info(f"Calendar event updated: event_id={final_event_id}") + actual_connector_id = connector.id - kb_message_suffix = "" - if document_id is not None: - try: - from app.services.google_calendar import GoogleCalendarKBSyncService + logger.info( + f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" + ) - kb_service = GoogleCalendarKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - event_id=final_event_id, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + is_composio_calendar = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ) + if is_composio_calendar: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } + else: + config_data = dict(connector.config) + + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and app_config.SECRET_KEY: + token_encryption = TokenEncryption(app_config.SECRET_KEY) + for key in ("token", "refresh_token", "client_secret"): + if config_data.get(key): + config_data[key] = token_encryption.decrypt_token( + config_data[key] + ) + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") + + creds = Credentials( + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." - ) - else: - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." - return { - "status": "success", - "event_id": final_event_id, - "html_link": updated.get("htmlLink"), - "message": f"Successfully updated the calendar event.{kb_message_suffix}", - } + update_body: dict[str, Any] = {} + if final_new_summary is not None: + update_body["summary"] = final_new_summary + if final_new_start_datetime is not None: + update_body["start"] = _build_time_body( + final_new_start_datetime, context + ) + if final_new_end_datetime is not None: + update_body["end"] = _build_time_body( + final_new_end_datetime, context + ) + if final_new_description is not None: + update_body["description"] = final_new_description + if final_new_location is not None: + update_body["location"] = final_new_location + if final_new_attendees is not None: + update_body["attendees"] = [ + {"email": e.strip()} for e in final_new_attendees if e.strip() + ] + + if not update_body: + return { + "status": "error", + "message": "No changes specified. 
Please provide at least one field to update.", + } + + try: + if is_composio_calendar: + from app.services.composio_service import ComposioService + + composio_params: dict[str, Any] = { + "calendar_id": "primary", + "event_id": final_event_id, + } + if final_new_summary is not None: + composio_params["summary"] = final_new_summary + if final_new_start_datetime is not None: + composio_params["start_time"] = final_new_start_datetime + if final_new_end_datetime is not None: + composio_params["end_time"] = final_new_end_datetime + if final_new_description is not None: + composio_params["description"] = final_new_description + if final_new_location is not None: + composio_params["location"] = final_new_location + if final_new_attendees is not None: + composio_params["attendees"] = [ + e.strip() for e in final_new_attendees if e.strip() + ] + if not _is_date_only( + final_new_start_datetime or final_new_end_datetime or "" + ): + composio_params["timezone"] = context.get("timezone", "UTC") + + composio_result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLECALENDAR_PATCH_EVENT", + params=composio_params, + entity_id=f"surfsense_{user_id}", + ) + if not composio_result.get("success"): + raise RuntimeError( + composio_result.get( + "error", "Unknown Composio Calendar error" + ) + ) + updated = composio_result.get("data", {}) + if isinstance(updated, dict): + updated = updated.get("data", updated) + if isinstance(updated, dict): + updated = updated.get("response_data", updated) + else: + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) + ) + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .patch( + calendarId="primary", + eventId=final_event_id, + body=update_body, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info(f"Calendar event updated: event_id={final_event_id}") + + kb_message_suffix = "" + if document_id is not None: + try: + from app.services.google_calendar import ( + GoogleCalendarKBSyncService, + ) + + kb_service = GoogleCalendarKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + event_id=final_event_id, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") + kb_message_suffix = " The knowledge base will be updated in the next scheduled sync." + + return { + "status": "success", + "event_id": final_event_id, + "html_link": updated.get("htmlLink"), + "message": f"Successfully updated the calendar event.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index 2becec100..66199ca67 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -23,6 +24,25 @@ def create_create_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured create_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_google_drive_file( name: str, @@ -65,7 +85,7 @@ def create_create_google_drive_file_tool( f"create_google_drive_file called: name='{name}', type='{file_type}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", @@ -78,225 +98,232 @@ def create_create_google_drive_file_tool( } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Google Drive accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Google Drive accounts need re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - - logger.info( - f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" - ) - result = request_approval( - action_type="google_drive_file_creation", - tool_name="create_google_drive_file", - params={ - "name": name, - "file_type": file_type, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not created. Do not ask again or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_file_type = result.params.get("file_type", file_type) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - mime_type = _MIME_MAP.get(final_file_type) - if not mime_type: - return { - "status": "error", - "message": f"Unsupported file type '{final_file_type}'.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - if final_connector_id is not None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Google Drive connector found. 
Please connect Google Drive in your workspace settings.", - } - actual_connector_id = connector.id + return {"status": "error", "message": context["error"]} - logger.info( - f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" - ) - - is_composio_drive = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ) - if is_composio_drive: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Drive connector.", - } - client = GoogleDriveClient( - session=db_session, - connector_id=actual_connector_id, - ) - try: - if is_composio_drive: - from app.services.composio_service import ComposioService - - params: dict[str, Any] = { - "name": final_name, - "mimeType": mime_type, - "fields": "id,name,webViewLink,mimeType", - } - if final_parent_folder_id: - params["parents"] = [final_parent_folder_id] - if final_content: - params["description"] = final_content[:4096] - - result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLEDRIVE_CREATE_FILE", - params=params, - entity_id=f"surfsense_{user_id}", - ) - if not result.get("success"): - raise RuntimeError( - result.get("error", "Unknown Composio Drive error") - ) - created = result.get("data", {}) - if isinstance(created, dict): - created = created.get("data", created) - if isinstance(created, dict): - created = created.get("response_data", created) - if not isinstance(created, dict): - created = {} - else: - created = await client.create_file( - name=final_name, - mime_type=mime_type, - parent_folder_id=final_parent_folder_id, - content=final_content, - ) - except HttpError as http_err: - if http_err.resp.status == 403: + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + "All Google Drive accounts have expired authentication" ) - try: - from sqlalchemy.orm.attributes import flag_modified - - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "auth_error", + "message": "All connected Google Drive accounts need re-authentication. 
Please re-authenticate in your connector settings.", + "connector_type": "google_drive", } - raise - logger.info( - f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.google_drive import GoogleDriveKBSyncService - - kb_service = GoogleDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=mime_type, - web_view_link=created.get("webViewLink"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, + logger.info( + f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" + ) + result = request_approval( + action_type="google_drive_file_creation", + tool_name="create_google_drive_file", + params={ + "name": name, + "file_type": file_type, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_view_link": created.get("webViewLink"), - "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not created. 
Do not ask again or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_file_type = result.params.get("file_type", file_type) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + mime_type = _MIME_MAP.get(final_file_type) + if not mime_type: + return { + "status": "error", + "message": f"Unsupported file type '{final_file_type}'.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Drive connector found. 
Please connect Google Drive in your workspace settings.", + } + actual_connector_id = connector.id + + logger.info( + f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + client = GoogleDriveClient( + session=db_session, + connector_id=actual_connector_id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + params: dict[str, Any] = { + "name": final_name, + "mimeType": mime_type, + "fields": "id,name,webViewLink,mimeType", + } + if final_parent_folder_id: + params["parents"] = [final_parent_folder_id] + if final_content: + params["description"] = final_content[:4096] + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_CREATE_FILE", + params=params, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + created = result.get("data", {}) + if isinstance(created, dict): + created = created.get("data", created) + if isinstance(created, dict): + created = created.get("response_data", created) + if not isinstance(created, dict): + created = {} + else: + created = await client.create_file( + name=final_name, + mime_type=mime_type, + parent_folder_id=final_parent_folder_id, + content=final_content, + ) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.google_drive import GoogleDriveKBSyncService + + kb_service = GoogleDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=mime_type, + web_view_link=created.get("webViewLink"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." 
+ except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_view_link": created.get("webViewLink"), + "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index 3c404527e..b3c9240d8 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.google_drive.client import GoogleDriveClient +from app.db import async_session_maker from app.services.google_drive import GoogleDriveToolMetadataService logger = logging.getLogger(__name__) @@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_google_drive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Google Drive connector + user_id: User ID for fetching user-specific context + + Returns: + Configured delete_google_drive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_google_drive_file( file_name: str, @@ -55,211 +75,214 @@ def create_delete_google_drive_file_tool( f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "Google Drive tool not properly configured. Please contact support.", } try: - metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_trash_context( - search_space_id, user_id, file_name - ) - - if "error" in context: - error_msg = context["error"] - if "not found" in error_msg.lower(): - logger.warning(f"File not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - logger.error(f"Failed to fetch trash context: {error_msg}") - return {"status": "error", "message": error_msg} - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Google Drive account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + metadata_service = GoogleDriveToolMetadataService(db_session) + context = await metadata_service.get_trash_context( + search_space_id, user_id, file_name ) - return { - "status": "auth_error", - "message": "The Google Drive account for this file needs re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "google_drive", - } - file = context["file"] - file_id = file["file_id"] - document_id = file.get("document_id") - connector_id_from_context = context["account"]["id"] + if "error" in context: + error_msg = context["error"] + if "not found" in error_msg.lower(): + logger.warning(f"File not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + logger.error(f"Failed to fetch trash context: {error_msg}") + return {"status": "error", "message": error_msg} - if not file_id: - return { - "status": "error", - "message": "File ID is missing from the indexed document. Please re-index the file and try again.", - } + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Google Drive account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "google_drive", + } - logger.info( - f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="google_drive_file_trash", - tool_name="delete_google_drive_file", - params={ - "file_id": file_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + file = context["file"] + file_id = file["file_id"] + document_id = file.get("document_id") + connector_id_from_context = context["account"]["id"] - if result.rejected: - return { - "status": "rejected", - "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.", - } - - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this file.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - _drive_types = [ - SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type.in_(_drive_types), - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Google Drive connector is invalid or has been disconnected.", - } - - logger.info( - f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" - ) - - is_composio_drive = ( - connector.connector_type - == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ) - if is_composio_drive: - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: + if not file_id: return { "status": "error", - "message": "Composio connected account ID not found for this Drive connector.", + "message": "File ID is missing from the indexed document. 
Please re-index the file and try again.", } - client = GoogleDriveClient( - session=db_session, - connector_id=connector.id, - ) - try: - if is_composio_drive: - from app.services.composio_service import ComposioService - - result = await ComposioService().execute_tool( - connected_account_id=cca_id, - tool_name="GOOGLEDRIVE_TRASH_FILE", - params={"file_id": final_file_id}, - entity_id=f"surfsense_{user_id}", - ) - if not result.get("success"): - raise RuntimeError( - result.get("error", "Unknown Composio Drive error") - ) - else: - await client.trash_file(file_id=final_file_id) - except HttpError as http_err: - if http_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {http_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", - } - raise - - logger.info( - f"Google Drive file deleted (moved to trash): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file['name']}' to trash.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to trash, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + logger.info( + f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="google_drive_file_trash", + tool_name="delete_google_drive_file", + params={ + "file_id": file_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. The file was not trashed. 
Do not ask again or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this file.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + _drive_types = [ + SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_drive_types), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Google Drive connector is invalid or has been disconnected.", + } + + logger.info( + f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" + ) + + is_composio_drive = ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR + ) + if is_composio_drive: + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Drive connector.", + } + + client = GoogleDriveClient( + session=db_session, + connector_id=connector.id, + ) + try: + if is_composio_drive: + from app.services.composio_service import ComposioService + + result = await ComposioService().execute_tool( + connected_account_id=cca_id, + tool_name="GOOGLEDRIVE_TRASH_FILE", + params={"file_id": final_file_id}, + entity_id=f"surfsense_{user_id}", + ) + if not result.get("success"): + raise RuntimeError( + result.get("error", "Unknown Composio Drive error") + ) + else: + await client.trash_file(file_id=final_file_id) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {http_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Google Drive account needs additional permissions. 
Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Google Drive file deleted (moved to trash): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file['name']}' to trash.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to trash, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py index 8b40dde65..0b04f1642 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,28 @@ def create_create_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the create_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Jira API + blocking. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
+ search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured create_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def create_jira_issue( project_key: str, @@ -49,158 +72,167 @@ def create_create_jira_issue_tool( f"create_jira_issue called: project_key='{project_key}', summary='{summary}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - return {"status": "error", "message": context["error"]} - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected Jira accounts need re-authentication.", - "connector_type": "jira", - } - - result = request_approval( - action_type="jira_issue_creation", - tool_name="create_jira_issue", - params={ - "project_key": project_key, - "summary": summary, - "issue_type": issue_type, - "description": description, - "priority": priority, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_project_key = result.params.get("project_key", project_key) - final_summary = result.params.get("summary", summary) - final_issue_type = result.params.get("issue_type", issue_type) - final_description = result.params.get("description", description) - final_priority = result.params.get("priority", priority) - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_summary or not final_summary.strip(): - return {"status": "error", "message": "Issue summary cannot be empty."} - if not final_project_key: - return {"status": "error", "message": "A project must be selected."} - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - return {"status": "error", "message": "No Jira connector found."} - actual_connector_id = connector.id - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, - ) + + if "error" in context: + return {"status": "error", "message": context["error"]} + + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + return 
{ + "status": "auth_error", + "message": "All connected Jira accounts need re-authentication.", + "connector_type": "jira", + } + + result = request_approval( + action_type="jira_issue_creation", + tool_name="create_jira_issue", + params={ + "project_key": project_key, + "summary": summary, + "issue_type": issue_type, + "description": description, + "priority": priority, + "connector_id": connector_id, + }, + context=context, ) - connector = result.scalars().first() - if not connector: + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_project_key = result.params.get("project_key", project_key) + final_summary = result.params.get("summary", summary) + final_issue_type = result.params.get("issue_type", issue_type) + final_description = result.params.get("description", description) + final_priority = result.params.get("priority", priority) + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_summary or not final_summary.strip(): return { "status": "error", - "message": "Selected Jira connector is invalid.", + "message": "Issue summary cannot be empty.", } + if not final_project_key: + return {"status": "error", "message": "A project must be selected."} - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=actual_connector_id - ) - jira_client = await jira_history._get_jira_client() - api_result = await asyncio.to_thread( - jira_client.create_issue, - project_key=final_project_key, - summary=final_summary, - issue_type=final_issue_type, - description=final_description, - priority=final_priority, - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - _conn = connector - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - pass - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + from sqlalchemy.future import select - issue_key = api_result.get("key", "") - issue_url = ( - f"{jira_history._base_url}/browse/{issue_key}" - if jira_history._base_url and issue_key - else "" - ) + from app.db import SearchSourceConnector, SearchSourceConnectorType - kb_message_suffix = "" - try: - from app.services.jira import JiraKBSyncService - - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=issue_key, - issue_identifier=issue_key, - issue_title=final_summary, - description=final_description, - state="To Do", - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Jira connector found.", + } + actual_connector_id = connector.id else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
- except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - return { - "status": "success", - "issue_key": issue_key, - "issue_url": issue_url, - "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", - } + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=actual_connector_id + ) + jira_client = await jira_history._get_jira_client() + api_result = await asyncio.to_thread( + jira_client.create_issue, + project_key=final_project_key, + summary=final_summary, + issue_type=final_issue_type, + description=final_description, + priority=final_priority, + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + _conn = connector + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_key = api_result.get("key", "") + issue_url = ( + f"{jira_history._base_url}/browse/{issue_key}" + if jira_history._base_url and issue_key + else "" + ) + + kb_message_suffix = "" + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=issue_key, + issue_identifier=issue_key, + issue_title=final_summary, + description=final_description, + state="To Do", + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
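
For readers skimming the repeated docstrings in these hunks: the refactored factories all share one shape, sketched below with illustrative names (create_example_tool, example_tool, and query are not part of this change; async_session_maker and the del-parameter convention are). The point is that the factory closure captures only plain IDs, and every invocation opens and closes its own AsyncSession, so a compiled agent cached across requests never holds a stale per-request session.

from langchain_core.tools import tool  # decorator import assumed; the real import sits in each file's unchanged context lines

from app.db import async_session_maker


def create_example_tool(
    db_session=None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Illustrative factory following the per-call-session pattern above."""
    del db_session  # reserved for registry compatibility; never captured

    @tool
    async def example_tool(query: str) -> dict:
        """Illustrative tool body."""
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        async with async_session_maker() as session:
            # Run whatever queries the tool needs on `session`; the session
            # is closed when the block exits, so nothing leaks into the
            # cached closure between requests.
            return {"status": "success", "query": query}

    return example_tool
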
+ + return { + "status": "success", + "issue_key": issue_key, + "issue_url": issue_url, + "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py index 6466c80ea..c41aedad9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_delete_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the delete_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured delete_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def delete_jira_issue( issue_title_or_key: str, @@ -44,130 +65,136 @@ def create_delete_jira_issue_tool( f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_deletion_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data["document_id"] - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_deletion", - tool_name="delete_jira_issue", - params={ - "issue_key": issue_key, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_deletion_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", + } + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} + + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data["document_id"] + connector_id_from_context = context.get("account", {}).get("id") + + result = request_approval( + action_type="jira_issue_deletion", + tool_name="delete_jira_issue", + params={ + "issue_key": issue_key, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread(jira_client.delete_issue, final_issue_key) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - from app.db import Document + final_issue_key = result.params.get("issue_key", issue_key) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } - message = f"Jira issue {final_issue_key} deleted successfully." - if deleted_from_kb: - message += " Also removed from the knowledge base." + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread(jira_client.delete_issue, final_issue_key) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise - return { - "status": "success", - "issue_key": final_issue_key, - "deleted_from_kb": deleted_from_kb, - "message": message, - } + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + from app.db import Document + + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + + message = f"Jira issue {final_issue_key} deleted successfully." + if deleted_from_kb: + message += " Also removed from the knowledge base." 
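
The 403 handling that recurs in the Drive and Jira hunks follows a single pattern, shown here in isolation as a hedged sketch (the helper name mark_auth_expired is illustrative; flag_modified and the auth_expired config key come from the diff). Because connector.config is a JSON column, the code replaces the dict and calls flag_modified so SQLAlchemy actually emits the UPDATE, and the bookkeeping is deliberately best-effort so the original API error still propagates.

from sqlalchemy.orm.attributes import flag_modified


async def mark_auth_expired(db_session, connector) -> None:
    """Best-effort: flag a connector's config so the UI prompts re-auth."""
    try:
        if not connector.config.get("auth_expired"):
            # Replace the dict rather than mutating in place, then tell
            # SQLAlchemy the JSON column changed; otherwise no UPDATE is issued.
            connector.config = {**connector.config, "auth_expired": True}
            flag_modified(connector, "config")
            await db_session.commit()
    except Exception:
        # Never let this bookkeeping mask the original 403 from the API.
        pass
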
+ + return { + "status": "success", + "issue_key": final_issue_key, + "deleted_from_kb": deleted_from_kb, + "message": message, + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py index f6e586a2e..0fd7b28b3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py @@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified from app.agents.new_chat.tools.hitl import request_approval from app.connectors.jira_history import JiraHistoryConnector +from app.db import async_session_maker from app.services.jira import JiraToolMetadataService logger = logging.getLogger(__name__) @@ -19,6 +20,26 @@ def create_update_jira_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): + """Factory function to create the update_jira_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + search_space_id: Search space ID to find the Jira connector + user_id: User ID for fetching user-specific context + connector_id: Optional specific connector ID (if known) + + Returns: + Configured update_jira_issue tool + """ + del db_session # per-call session — see docstring + @tool async def update_jira_issue( issue_title_or_key: str, @@ -48,169 +69,177 @@ def create_update_jira_issue_tool( f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Jira tool not properly configured."} try: - metadata_service = JiraToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_title_or_key - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "jira", - } - if "not found" in error_msg.lower(): - return {"status": "not_found", "message": error_msg} - return {"status": "error", "message": error_msg} - - issue_data = context["issue"] - issue_key = issue_data["issue_id"] - document_id = issue_data.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - result = request_approval( - action_type="jira_issue_update", - tool_name="update_jira_issue", - params={ - "issue_key": issue_key, - "document_id": document_id, - "new_summary": new_summary, - "new_description": new_description, - "new_priority": new_priority, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_key = result.params.get("issue_key", issue_key) - final_summary = result.params.get("new_summary", new_summary) - final_description = result.params.get("new_description", new_description) - final_priority = result.params.get("new_priority", new_priority) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_document_id = result.params.get("document_id", document_id) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if not final_connector_id: - return { - "status": "error", - "message": "No connector found for this issue.", - } - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.JIRA_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = JiraToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_title_or_key ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Jira connector is invalid.", - } - fields: dict[str, Any] = {} - if final_summary: - fields["summary"] = final_summary - if final_description is not None: - fields["description"] = { - "type": "doc", - "version": 1, - "content": [ - { - "type": "paragraph", - "content": [{"type": "text", "text": final_description}], + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "jira", } - ], - } - if final_priority: - fields["priority"] = {"name": final_priority} + if "not found" in error_msg.lower(): + return {"status": "not_found", "message": error_msg} + return {"status": "error", "message": error_msg} - if not fields: - return {"status": "error", "message": "No changes specified."} + issue_data = context["issue"] + issue_key = issue_data["issue_id"] + document_id = issue_data.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") - try: - jira_history = JiraHistoryConnector( - session=db_session, connector_id=final_connector_id + result = request_approval( + action_type="jira_issue_update", + tool_name="update_jira_issue", + params={ + "issue_key": issue_key, + "document_id": document_id, + "new_summary": new_summary, + "new_description": new_description, + "new_priority": new_priority, + "connector_id": connector_id_from_context, + }, + context=context, ) - jira_client = await jira_history._get_jira_client() - await asyncio.to_thread( - jira_client.update_issue, final_issue_key, fields - ) - except Exception as api_err: - if "status code 403" in str(api_err).lower(): - try: - connector.config = {**connector.config, "auth_expired": True} - flag_modified(connector, "config") - await db_session.commit() - except Exception: - pass + + if result.rejected: return { - "status": "insufficient_permissions", - "connector_id": final_connector_id, - "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", } - raise - issue_url = ( - f"{jira_history._base_url}/browse/{final_issue_key}" - if jira_history._base_url and final_issue_key - else "" - ) + final_issue_key = result.params.get("issue_key", issue_key) + final_summary = result.params.get("new_summary", new_summary) + final_description = result.params.get( + "new_description", new_description + ) + final_priority = result.params.get("new_priority", new_priority) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_document_id = result.params.get("document_id", document_id) - kb_message_suffix = "" - if final_document_id: - try: - from app.services.jira import JiraKBSyncService + from sqlalchemy.future import select - kb_service = JiraKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_key, - user_id=user_id, - search_space_id=search_space_id, + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if not final_connector_id: + return { + "status": "error", + "message": "No connector found for this issue.", + } + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.JIRA_CONNECTOR, ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Jira connector is invalid.", + } + + fields: dict[str, Any] = {} + if final_summary: + fields["summary"] = final_summary + if final_description is not None: + fields["description"] = { + "type": "doc", + "version": 1, + "content": [ + { + "type": "paragraph", + "content": [ + {"type": "text", "text": final_description} + ], + } + ], + } + if final_priority: + fields["priority"] = {"name": final_priority} + + if not fields: + return {"status": "error", "message": "No changes specified."} + + try: + jira_history = JiraHistoryConnector( + session=db_session, connector_id=final_connector_id + ) + jira_client = await jira_history._get_jira_client() + await asyncio.to_thread( + jira_client.update_issue, final_issue_key, fields + ) + except Exception as api_err: + if "status code 403" in str(api_err).lower(): + try: + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + pass + return { + "status": "insufficient_permissions", + "connector_id": final_connector_id, + "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + issue_url = ( + f"{jira_history._base_url}/browse/{final_issue_key}" + if jira_history._base_url and final_issue_key + else "" + ) + + kb_message_suffix = "" + if final_document_id: + try: + from app.services.jira import JiraKBSyncService + + kb_service = JiraKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_key, + user_id=user_id, + search_space_id=search_space_id, ) - else: + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = ( + " The knowledge base will be updated in the next sync." + ) + except Exception as kb_err: + logger.warning(f"KB sync after update failed: {kb_err}") kb_message_suffix = ( " The knowledge base will be updated in the next sync." ) - except Exception as kb_err: - logger.warning(f"KB sync after update failed: {kb_err}") - kb_message_suffix = ( - " The knowledge base will be updated in the next sync." - ) - return { - "status": "success", - "issue_key": final_issue_key, - "issue_url": issue_url, - "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", - } + return { + "status": "success", + "issue_key": final_issue_key, + "issue_url": issue_url, + "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py index ff254e133..f897bee7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/create_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_create_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the create_linear_issue tool. + """Factory function to create the create_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_create_linear_issue_tool( Returns: Configured create_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def create_linear_issue( @@ -65,7 +73,7 @@ def create_create_linear_issue_tool( """ logger.info(f"create_linear_issue called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -75,160 +83,170 @@ def create_create_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return {"status": "error", "message": "Issue title cannot be empty."} - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: + + if "error" in context: + logger.error( + f"Failed to fetch creation context: {context['error']}" + ) + return {"status": "error", "message": context["error"]} + + workspaces = context.get("workspaces", []) + if workspaces and all(w.get("auth_expired") for w in workspaces): + logger.warning("All Linear accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "linear", + } + + logger.info(f"Requesting approval for creating Linear issue: '{title}'") + result = request_approval( + action_type="linear_issue_creation", + tool_name="create_linear_issue", + params={ + "title": title, + "description": description, + "team_id": None, + "state_id": None, + "assignee_id": None, + "priority": None, + "label_ids": [], + "connector_id": connector_id, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue creation rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_title = result.params.get("title", title) + final_description = result.params.get("description", description) + final_team_id = result.params.get("team_id") + final_state_id = result.params.get("state_id") + final_assignee_id = result.params.get("assignee_id") + final_priority = result.params.get("priority") + final_label_ids = result.params.get("label_ids") or [] + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") return { "status": "error", - "message": "No Linear connector found. 
Please connect Linear in your workspace settings.", + "message": "Issue title cannot be empty.", } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: + if not final_team_id: return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "A team must be selected to create an issue.", } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) + from sqlalchemy.future import select - if result.get("status") == "error": - logger.error(f"Failed to create Linear issue: {result.get('message')}") - return {"status": "error", "message": result.get("message")} + from app.db import SearchSourceConnector, SearchSourceConnectorType - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Linear connector found. Please connect Linear in your workspace settings.", + } + actual_connector_id = connector.id + logger.info(f"Found Linear connector: id={actual_connector_id}") else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." 
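
The approval flow repeated across the Drive, Jira, and Linear tools is easiest to read in isolation; a minimal sketch follows (the function name propose_and_confirm and its literal parameters are placeholders, while request_approval and its rejected/params attributes come from the existing HITL helper). The tool proposes parameters, the user can approve, edit, or reject them, and the tool then prefers the returned values, falling back to its own proposal for any key the approval payload omits.

from app.agents.new_chat.tools.hitl import request_approval


def propose_and_confirm(title: str, connector_id: int | None) -> dict:
    """Illustrative: run one approval round for a proposed action."""
    result = request_approval(
        action_type="linear_issue_creation",
        tool_name="create_linear_issue",
        params={"title": title, "connector_id": connector_id},
        context={},  # metadata (teams, workspaces, ...) rendered in the approval UI
    )
    if result.rejected:
        return {
            "status": "rejected",
            "message": "User declined. Do not retry or suggest alternatives.",
        }
    # The user may have edited values in the approval UI; fall back to the
    # original proposal for any key the approval payload leaves out.
    return {
        "title": result.params.get("title", title),
        "connector_id": result.params.get("connector_id", connector_id),
    }
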
+ result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={actual_connector_id}") - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } + logger.info( + f"Creating Linear issue with final params: title='{final_title}'" + ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) + result = await linear_client.create_issue( + team_id=final_team_id, + title=final_title, + description=final_description, + state_id=final_state_id, + assignee_id=final_assignee_id, + priority=final_priority, + label_ids=final_label_ids if final_label_ids else None, + ) + + if result.get("status") == "error": + logger.error( + f"Failed to create Linear issue: {result.get('message')}" + ) + return {"status": "error", "message": result.get("message")} + + logger.info( + f"Linear issue created: {result.get('identifier')} - {result.get('title')}" + ) + + kb_message_suffix = "" + try: + from app.services.linear import LinearKBSyncService + + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + issue_id=result.get("id"), + issue_identifier=result.get("identifier", ""), + issue_title=result.get("title", final_title), + issue_url=result.get("url"), + description=final_description, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "issue_id": result.get("id"), + "identifier": result.get("identifier"), + "url": result.get("url"), + "message": (result.get("message", "") + kb_message_suffix), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py index 29ef0cdf2..c5039a8eb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/delete_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_delete_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the delete_linear_issue tool. 
+ """Factory function to create the delete_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Linear connector user_id: User ID for finding the correct Linear connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_linear_issue_tool( Returns: Configured delete_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def delete_linear_issue( @@ -73,7 +81,7 @@ def create_delete_linear_issue_tool( f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -83,149 +91,152 @@ def create_delete_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, issue_ref ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for delete context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch delete context: {error_msg}") + return {"status": "error", "message": error_msg} + + issue_id = context["issue"]["id"] + issue_identifier = context["issue"].get("identifier", "") + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") + + logger.info( + f"Requesting approval for deleting Linear issue: '{issue_ref}' " + f"(id={issue_id}, delete_from_kb={delete_from_kb})" + ) + result = request_approval( + action_type="linear_issue_deletion", + tool_name="delete_linear_issue", + params={ + "issue_id": issue_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Linear issue deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Linear issue with final params: issue_id={final_issue_id}, " + f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + actual_connector_id = connector.id + logger.info(f"Validated Linear connector: id={actual_connector_id}") + else: + logger.error("No connector found for this issue") return { "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", + "message": "No connector found for this issue.", } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) + linear_client = LinearConnector( + session=db_session, connector_id=actual_connector_id + ) - result = await linear_client.archive_issue(issue_id=final_issue_id) + result = await linear_client.archive_issue(issue_id=final_issue_id) - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info( + f"archive_issue result: {result.get('status')} - {result.get('message', '')}" + ) - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from app.db import Document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" ) - else: - 
logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if issue_identifier: + result["message"] = ( + f"Issue {issue_identifier} archived successfully." + ) + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} Also removed from the knowledge base." + ) - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py index f35d0dddd..d610ce2b7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py +++ b/surfsense_backend/app/agents/new_chat/tools/linear/update_issue.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.linear_connector import LinearAPIError, LinearConnector +from app.db import async_session_maker from app.services.linear import LinearKBSyncService, LinearToolMetadataService logger = logging.getLogger(__name__) @@ -17,11 +18,17 @@ def create_update_linear_issue_tool( user_id: str | None = None, connector_id: int | None = None, ): - """ - Factory function to create the update_linear_issue tool. + """Factory function to create the update_linear_issue tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Args: - db_session: Database session for accessing the Linear connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
search_space_id: Search space ID to find the Linear connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_linear_issue_tool( Returns: Configured update_linear_issue tool """ + del db_session # per-call session — see docstring @tool async def update_linear_issue( @@ -86,7 +94,7 @@ def create_update_linear_issue_tool( """ logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Linear tool not properly configured - missing required parameters" ) @@ -96,176 +104,177 @@ def create_update_linear_issue_tool( } try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for update context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - team = context.get("team", {}) - new_state_id = _resolve_state(team, new_state_name) - new_assignee_id = _resolve_assignee(team, new_assignee_email) - new_label_ids = _resolve_labels(team, new_label_names) - - logger.info( - f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" - ) - result = request_approval( - action_type="linear_issue_update", - tool_name="update_linear_issue", - params={ - "issue_id": issue_id, - "document_id": document_id, - "new_title": new_title, - "new_description": new_description, - "new_state_id": new_state_id, - "new_assignee_id": new_assignee_id, - "new_priority": new_priority, - "new_label_ids": new_label_ids, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue update rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_document_id = result.params.get("document_id", document_id) - final_new_title = result.params.get("new_title", new_title) - final_new_description = result.params.get( - "new_description", new_description - ) - final_new_state_id = result.params.get("new_state_id", new_state_id) - final_new_assignee_id = result.params.get( - "new_assignee_id", new_assignee_id - ) - final_new_priority = result.params.get("new_priority", new_priority) - final_new_label_ids: list[str] | None = result.params.get( - "new_label_ids", new_label_ids - ) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - if not final_connector_id: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, + async with async_session_maker() as db_session: + metadata_service = LinearToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, issue_ref ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={final_connector_id}") - logger.info( - f"Updating Linear issue with final params: issue_id={final_issue_id}" - ) - linear_client = LinearConnector( - session=db_session, connector_id=final_connector_id - ) - updated_issue = await linear_client.update_issue( - issue_id=final_issue_id, - title=final_new_title, - description=final_new_description, - state_id=final_new_state_id, - assignee_id=final_new_assignee_id, - priority=final_new_priority, - label_ids=final_new_label_ids, - ) + if "error" in context: + error_msg = context["error"] + if context.get("auth_expired"): + logger.warning(f"Auth expired for update context: {error_msg}") + return { + "status": "auth_error", + "message": error_msg, + "connector_id": context.get("connector_id"), + "connector_type": "linear", + } + if "not found" in error_msg.lower(): + logger.warning(f"Issue not found: {error_msg}") + return {"status": "not_found", "message": error_msg} + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return {"status": "error", "message": error_msg} - if updated_issue.get("status") == "error": - logger.error( - f"Failed to update Linear issue: {updated_issue.get('message')}" - ) - return { - "status": "error", - "message": updated_issue.get("message"), - } + issue_id = context["issue"]["id"] + document_id = context["issue"]["document_id"] + connector_id_from_context = context.get("workspace", {}).get("id") - logger.info( - f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" - ) + team = context.get("team", {}) + new_state_id = _resolve_state(team, new_state_name) + new_assignee_id = _resolve_assignee(team, new_assignee_email) + 
new_label_ids = _resolve_labels(team, new_label_names) - if final_document_id is not None: logger.info( - f"Updating knowledge base for document {final_document_id}..." + f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})" ) - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=final_document_id, - issue_id=final_issue_id, - user_id=user_id, - search_space_id=search_space_id, + result = request_approval( + action_type="linear_issue_update", + tool_name="update_linear_issue", + params={ + "issue_id": issue_id, + "document_id": document_id, + "new_title": new_title, + "new_description": new_description, + "new_state_id": new_state_id, + "new_assignee_id": new_assignee_id, + "new_priority": new_priority, + "new_label_ids": new_label_ids, + "connector_id": connector_id_from_context, + }, + context=context, ) - if kb_result["status"] == "success": - logger.info( - f"Knowledge base successfully updated for issue {final_issue_id}" - ) - kb_message = " Your knowledge base has also been updated." - elif kb_result["status"] == "not_indexed": - kb_message = " This issue will be added to your knowledge base in the next scheduled sync." - else: - logger.warning( - f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" - ) - kb_message = " Your knowledge base will be updated in the next scheduled sync." - else: - kb_message = "" - identifier = updated_issue.get("identifier") - default_msg = f"Issue {identifier} updated successfully." - return { - "status": "success", - "identifier": identifier, - "url": updated_issue.get("url"), - "message": f"{updated_issue.get('message', default_msg)}{kb_message}", - } + if result.rejected: + logger.info("Linear issue update rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_issue_id = result.params.get("issue_id", issue_id) + final_document_id = result.params.get("document_id", document_id) + final_new_title = result.params.get("new_title", new_title) + final_new_description = result.params.get( + "new_description", new_description + ) + final_new_state_id = result.params.get("new_state_id", new_state_id) + final_new_assignee_id = result.params.get( + "new_assignee_id", new_assignee_id + ) + final_new_priority = result.params.get("new_priority", new_priority) + final_new_label_ids: list[str] | None = result.params.get( + "new_label_ids", new_label_ids + ) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + if not final_connector_id: + logger.error("No connector found for this issue") + return { + "status": "error", + "message": "No connector found for this issue.", + } + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Linear connector is invalid or has been disconnected.", + } + logger.info(f"Validated Linear connector: id={final_connector_id}") + + logger.info( + f"Updating Linear issue with final params: issue_id={final_issue_id}" + ) + linear_client = LinearConnector( + session=db_session, connector_id=final_connector_id + ) + updated_issue = await linear_client.update_issue( + issue_id=final_issue_id, + title=final_new_title, + description=final_new_description, + state_id=final_new_state_id, + assignee_id=final_new_assignee_id, + priority=final_new_priority, + label_ids=final_new_label_ids, + ) + + if updated_issue.get("status") == "error": + logger.error( + f"Failed to update Linear issue: {updated_issue.get('message')}" + ) + return { + "status": "error", + "message": updated_issue.get("message"), + } + + logger.info( + f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" + ) + + if final_document_id is not None: + logger.info( + f"Updating knowledge base for document {final_document_id}..." + ) + kb_service = LinearKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=final_document_id, + issue_id=final_issue_id, + user_id=user_id, + search_space_id=search_space_id, + ) + if kb_result["status"] == "success": + logger.info( + f"Knowledge base successfully updated for issue {final_issue_id}" + ) + kb_message = " Your knowledge base has also been updated." + elif kb_result["status"] == "not_indexed": + kb_message = " This issue will be added to your knowledge base in the next scheduled sync." + else: + logger.warning( + f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}" + ) + kb_message = " Your knowledge base will be updated in the next scheduled sync." + else: + kb_message = "" + + identifier = updated_issue.get("identifier") + default_msg = f"Issue {identifier} updated successfully." 
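Every write tool in this change funnels through the same request_approval round-trip shown above: build a params dict, stop on result.rejected, then re-read each field from result.params with the original value as the fallback (the user may have edited the params in the approval UI). A condensed sketch of that shape — the rename_widget action and its fields are hypothetical; request_approval, .rejected and .params are the same interface the Linear/Notion/Luma tools in this diff consume, and in the real tools this sits inside the async @tool body:

from app.agents.new_chat.tools.hitl import request_approval


def approve_rename(widget_id: str, new_name: str) -> dict:
    # Ask the user to approve (and optionally edit) the proposed action.
    result = request_approval(
        action_type="widget_rename",   # hypothetical action type
        tool_name="rename_widget",     # hypothetical tool name
        params={"widget_id": widget_id, "new_name": new_name},
        context={"widget_id": widget_id},
    )
    if result.rejected:
        # Mirror the tools above: surface the rejection and do not retry.
        return {"status": "rejected", "message": "User declined."}
    # Fall back to the original values for any field the user left untouched.
    final_id = result.params.get("widget_id", widget_id)
    final_name = result.params.get("new_name", new_name)
    return {"status": "approved", "widget_id": final_id, "new_name": final_name}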
+ return { + "status": "success", + "identifier": identifier, + "url": updated_issue.get("url"), + "message": f"{updated_issue.get('message', default_msg)}{kb_message}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py index 0a24a988f..65c177d7a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers @@ -17,6 +18,23 @@ def create_create_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def create_luma_event( name: str, @@ -40,83 +58,86 @@ def create_create_luma_event_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - result = request_approval( - action_type="luma_create_event", - tool_name="create_luma_event", - params={ - "name": name, - "start_at": start_at, - "end_at": end_at, - "description": description, - "timezone": timezone, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. 
Event was not created.", - } - - final_name = result.params.get("name", name) - final_start = result.params.get("start_at", start_at) - final_end = result.params.get("end_at", end_at) - final_desc = result.params.get("description", description) - final_tz = result.params.get("timezone", timezone) - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - body: dict[str, Any] = { - "name": final_name, - "start_at": final_start, - "end_at": final_end, - "timezone": final_tz, - } - if final_desc: - body["description_md"] = final_desc - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{LUMA_API}/event/create", - headers=headers, - json=body, + result = request_approval( + action_type="luma_create_event", + tool_name="create_luma_event", + params={ + "name": name, + "start_at": start_at, + "end_at": end_at, + "description": description, + "timezone": timezone, + }, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Luma Plus subscription required to create events via API.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Event was not created.", + } - data = resp.json() - event_id = data.get("api_id") or data.get("event", {}).get("api_id") + final_name = result.params.get("name", name) + final_start = result.params.get("start_at", start_at) + final_end = result.params.get("end_at", end_at) + final_desc = result.params.get("description", description) + final_tz = result.params.get("timezone", timezone) - return { - "status": "success", - "event_id": event_id, - "message": f"Event '{final_name}' created on Luma.", - } + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + body: dict[str, Any] = { + "name": final_name, + "start_at": final_start, + "end_at": final_end, + "timezone": final_tz, + } + if final_desc: + body["description_md"] = final_desc + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{LUMA_API}/event/create", + headers=headers, + json=body, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Luma Plus subscription required to create events via API.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", + } + + data = resp.json() + event_id = data.get("api_id") or data.get("event", {}).get("api_id") + + return { + "status": "success", + "event_id": event_id, + "message": f"Event '{final_name}' created on Luma.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py index aec5ad220..6885c2049 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db 
import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_luma_events_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_luma_events tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_luma_events tool + """ + del db_session # per-call session — see docstring + @tool async def list_luma_events( max_results: int = 25, @@ -28,77 +47,80 @@ def create_list_luma_events_tool( Dictionary with status and a list of events including event_id, name, start_at, end_at, location, url. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} max_results = min(max_results, 50) try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - api_key = get_api_key(connector) - headers = luma_headers(api_key) + api_key = get_api_key(connector) + headers = luma_headers(api_key) - all_entries: list[dict] = [] - cursor = None + all_entries: list[dict] = [] + cursor = None - async with httpx.AsyncClient(timeout=20.0) as client: - while len(all_entries) < max_results: - params: dict[str, Any] = { - "limit": min(100, max_results - len(all_entries)) - } - if cursor: - params["cursor"] = cursor + async with httpx.AsyncClient(timeout=20.0) as client: + while len(all_entries) < max_results: + params: dict[str, Any] = { + "limit": min(100, max_results - len(all_entries)) + } + if cursor: + params["cursor"] = cursor - resp = await client.get( - f"{LUMA_API}/calendar/list-events", - headers=headers, - params=params, + resp = await client.get( + f"{LUMA_API}/calendar/list-events", + headers=headers, + params=params, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + entries = data.get("entries", []) + if not entries: + break + all_entries.extend(entries) + + next_cursor = data.get("next_cursor") + if not next_cursor: + break + cursor = next_cursor + + events = [] + for entry in all_entries[:max_results]: + ev = entry.get("event", {}) + geo = ev.get("geo_info", {}) + events.append( + { + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + } ) - if resp.status_code == 401: - return { - "status": "auth_error", - 
"message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", - } - - data = resp.json() - entries = data.get("entries", []) - if not entries: - break - all_entries.extend(entries) - - next_cursor = data.get("next_cursor") - if not next_cursor: - break - cursor = next_cursor - - events = [] - for entry in all_entries[:max_results]: - ev = entry.get("event", {}) - geo = ev.get("geo_info", {}) - events.append( - { - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - } - ) - - return {"status": "success", "events": events, "total": len(events)} + return {"status": "success", "events": events, "total": len(events)} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py index b37a9d617..a8484e9c0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_luma_event_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_luma_event tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_luma_event tool + """ + del db_session # per-call session — see docstring + @tool async def read_luma_event(event_id: str) -> dict[str, Any]: """Read detailed information about a specific Luma event. @@ -26,60 +45,63 @@ def create_read_luma_event_tool( Dictionary with status and full event details including description, attendees count, meeting URL. 
""" - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Luma tool not properly configured."} try: - connector = await get_luma_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Luma connector found."} - - api_key = get_api_key(connector) - headers = luma_headers(api_key) - - async with httpx.AsyncClient(timeout=15.0) as client: - resp = await client.get( - f"{LUMA_API}/events/{event_id}", - headers=headers, + async with async_session_maker() as db_session: + connector = await get_luma_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Luma connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Luma API key is invalid.", - "connector_type": "luma", - } - if resp.status_code == 404: - return { - "status": "not_found", - "message": f"Event '{event_id}' not found.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Luma API error: {resp.status_code}", + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + f"{LUMA_API}/events/{event_id}", + headers=headers, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } + if resp.status_code == 404: + return { + "status": "not_found", + "message": f"Event '{event_id}' not found.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } + + data = resp.json() + ev = data.get("event", data) + geo = ev.get("geo_info", {}) + + event_detail = { + "event_id": event_id, + "name": ev.get("name", ""), + "description": ev.get("description", ""), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location_name": geo.get("name", ""), + "address": geo.get("address", ""), + "url": ev.get("url", ""), + "meeting_url": ev.get("meeting_url", ""), + "visibility": ev.get("visibility", ""), + "cover_url": ev.get("cover_url", ""), } - data = resp.json() - ev = data.get("event", data) - geo = ev.get("geo_info", {}) - - event_detail = { - "event_id": event_id, - "name": ev.get("name", ""), - "description": ev.get("description", ""), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location_name": geo.get("name", ""), - "address": geo.get("address", ""), - "url": ev.get("url", ""), - "meeting_url": ev.get("meeting_url", ""), - "visibility": ev.get("visibility", ""), - "cover_url": ev.get("cover_url", ""), - } - - return {"status": "success", "event": event_detail} + return {"status": "success", "event": event_detail} except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py index 6efffe960..6ec95e9f0 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/create_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, 
NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,17 @@ def create_create_notion_page_tool( """ Factory function to create the create_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker`. This is critical for the compiled-agent + cache: the compiled graph (and therefore this closure) is reused + across HTTP requests, so capturing a per-request session here would + surface stale/closed sessions on cache hits. Per-call sessions also + keep the request's outer transaction free of long-running Notion API + blocking. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +39,7 @@ def create_create_notion_page_tool( Returns: Configured create_notion_page tool """ + del db_session # per-call session — see docstring @tool async def create_notion_page( @@ -67,7 +78,7 @@ def create_create_notion_page_tool( """ logger.info(f"create_notion_page called: title='{title}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -77,154 +88,157 @@ def create_create_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return { - "status": "error", - "message": context["error"], - } - - accounts = context.get("accounts", []) - if accounts and all(a.get("auth_expired") for a in accounts): - logger.warning("All Notion accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "notion", - } - - logger.info(f"Requesting approval for creating Notion page: '{title}'") - result = request_approval( - action_type="notion_page_creation", - tool_name="create_notion_page", - params={ - "title": title, - "content": content, - "parent_page_id": None, - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_content = result.params.get("content", content) - final_parent_page_id = result.params.get("parent_page_id") - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return { - "status": "error", - "message": "Page title cannot be empty. 
Please provide a valid title.", - } - - logger.info( - f"Creating Notion page with final params: title='{final_title}'" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_creation_context( + search_space_id, user_id ) - connector = result.scalars().first() - if not connector: - logger.warning( - f"No Notion connector found for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "No Notion connector found. Please connect Notion in your workspace settings.", - } - - actual_connector_id = connector.id - logger.info(f"Found Notion connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: + if "error" in context: logger.error( - f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + f"Failed to fetch creation context: {context['error']}" ) return { "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + "message": context["error"], } - logger.info(f"Validated Notion connector: id={actual_connector_id}") - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) + accounts = context.get("accounts", []) + if accounts and all(a.get("auth_expired") for a in accounts): + logger.warning("All Notion accounts have expired authentication") + return { + "status": "auth_error", + "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "notion", + } - result = await notion_connector.create_page( - title=final_title, - content=final_content, - parent_page_id=final_parent_page_id, - ) - logger.info( - f"create_page result: {result.get('status')} - {result.get('message', '')}" - ) + logger.info(f"Requesting approval for creating Notion page: '{title}'") + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={ + "title": title, + "content": content, + "parent_page_id": None, + "connector_id": connector_id, + }, + context=context, + ) - if result.get("status") == "success": - kb_message_suffix = "" - try: - from app.services.notion import NotionKBSyncService + if result.rejected: + logger.info("Notion page creation rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - page_id=result.get("page_id"), - page_title=result.get("title", final_title), - page_url=result.get("url"), - content=final_content, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = ( - " Your knowledge base has also been updated." + final_title = result.params.get("title", title) + final_content = result.params.get("content", content) + final_parent_page_id = result.params.get("parent_page_id") + final_connector_id = result.params.get("connector_id", connector_id) + + if not final_title or not final_title.strip(): + logger.error("Title is empty or contains only whitespace") + return { + "status": "error", + "message": "Page title cannot be empty. Please provide a valid title.", + } + + logger.info( + f"Creating Notion page with final params: title='{final_title}'" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + actual_connector_id = final_connector_id + if actual_connector_id is None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, ) - else: + ) + connector = result.scalars().first() + + if not connector: + logger.warning( + f"No Notion connector found for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "No Notion connector found. Please connect Notion in your workspace settings.", + } + + actual_connector_id = connector.id + logger.info(f"Found Notion connector: id={actual_connector_id}") + else: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == actual_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + logger.info(f"Validated Notion connector: id={actual_connector_id}") + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.create_page( + title=final_title, + content=final_content, + parent_page_id=final_parent_page_id, + ) + logger.info( + f"create_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success": + kb_message_suffix = "" + try: + from app.services.notion import NotionKBSyncService + + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + page_id=result.get("page_id"), + page_title=result.get("title", final_title), + page_url=result.get("url"), + content=final_content, + connector_id=actual_connector_id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." 
+ ) + else: + kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." - result["message"] = result.get("message", "") + kb_message_suffix + result["message"] = result.get("message", "") + kb_message_suffix - return result + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py index 07f7583d2..7b85da4c2 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/delete_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion.tool_metadata_service import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_delete_notion_page_tool( """ Factory function to create the delete_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
search_space_id: Search space ID to find the Notion connector user_id: User ID for finding the correct Notion connector connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_delete_notion_page_tool( Returns: Configured delete_notion_page tool """ + del db_session # per-call session — see docstring @tool async def delete_notion_page( @@ -63,7 +71,7 @@ def create_delete_notion_page_tool( f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -73,164 +81,167 @@ def create_delete_notion_page_tool( } try: - # Get page context (page_id, account, title) from indexed data - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), + async with async_session_maker() as db_session: + # Get page context (page_id, account, title) from indexed data + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_delete_context( + search_space_id, user_id, page_title ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - page_id = context.get("page_id") - connector_id_from_context = account.get("id") - document_id = context.get("document_id") - - logger.info( - f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" - ) - - result = request_approval( - action_type="notion_page_deletion", - tool_name="delete_notion_page", - params={ - "page_id": page_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. 
Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - # Validate the connector - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - # Create connector instance - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - # Delete the page from Notion - result = await notion_connector.delete_page(page_id=final_page_id) - logger.info( - f"delete_page result: {result.get('status')} - {result.get('message', '')}" - ) - - # If deletion was successful and user wants to delete from KB - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from sqlalchemy.future import select - - from app.db import Document - - # Get the document - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" - ) + logger.error(f"Failed to fetch delete context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } - # Update result with KB deletion status - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} (also removed from knowledge base)" + account = context.get("account", {}) + if account.get("auth_expired"): + logger.warning( + "Notion account %s has expired authentication", + account.get("id"), ) + return { + "status": 
"auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } - return result + page_id = context.get("page_id") + connector_id_from_context = account.get("id") + document_id = context.get("document_id") + + logger.info( + f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" + ) + + result = request_approval( + action_type="notion_page_deletion", + tool_name="delete_notion_page", + params={ + "page_id": page_id, + "connector_id": connector_id_from_context, + "delete_from_kb": delete_from_kb, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page deletion rejected by user") + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + logger.info( + f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + # Validate the connector + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + # Create connector instance + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + # Delete the page from Notion + result = await notion_connector.delete_page(page_id=final_page_id) + logger.info( + f"delete_page result: {result.get('status')} - {result.get('message', '')}" + ) + + # If deletion was successful and user wants to delete from KB + deleted_from_kb = False + if ( + result.get("status") == "success" + and final_delete_from_kb + and document_id + ): + try: + from sqlalchemy.future import select + + from app.db import Document + + # Get the document + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + document = doc_result.scalars().first() + + if document: + await db_session.delete(document) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + result["warning"] = ( + f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}" + ) + + # Update result with KB deletion status + if result.get("status") == "success": + result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + result["message"] = ( + f"{result.get('message', '')} (also removed from knowledge base)" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py index 85c08177c..df757476a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py +++ b/surfsense_backend/app/agents/new_chat/tools/notion/update_page.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector +from app.db import async_session_maker from app.services.notion import NotionToolMetadataService logger = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def create_update_notion_page_tool( """ Factory function to create the update_notion_page tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache (see + ``create_create_notion_page_tool`` for the full rationale). + Args: - db_session: Database session for accessing Notion connector + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
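The connector-validation query — select SearchSourceConnector filtered by id, search_space_id, user_id and connector_type, then scalars().first() — recurs in each of the Linear, Notion and OneDrive tools in this change. A shared helper along these lines (name and placement hypothetical; models and imports as used in this diff) would capture it once:

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType


async def resolve_connector(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    connector_type: SearchSourceConnectorType,
) -> SearchSourceConnector | None:
    """Return the connector only if it belongs to this user and search space."""
    result = await session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type == connector_type,
        )
    )
    return result.scalars().first()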
search_space_id: Search space ID to find the Notion connector user_id: User ID for fetching user-specific context connector_id: Optional specific connector ID (if known) @@ -29,6 +36,7 @@ def create_update_notion_page_tool( Returns: Configured update_notion_page tool """ + del db_session # per-call session — see docstring @tool async def update_notion_page( @@ -71,7 +79,7 @@ def create_update_notion_page_tool( f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) @@ -88,152 +96,155 @@ def create_update_notion_page_tool( } try: - metadata_service = NotionToolMetadataService(db_session) - context = await metadata_service.get_update_context( - search_space_id, user_id, page_title - ) - - if "error" in context: - error_msg = context["error"] - # Check if it's a "not found" error (softer handling for LLM) - if "not found" in error_msg.lower(): - logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } - else: - logger.error(f"Failed to fetch update context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } - - account = context.get("account", {}) - if account.get("auth_expired"): - logger.warning( - "Notion account %s has expired authentication", - account.get("id"), - ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } - - page_id = context.get("page_id") - document_id = context.get("document_id") - connector_id_from_context = context.get("account", {}).get("id") - - logger.info( - f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" - ) - result = request_approval( - action_type="notion_page_update", - tool_name="update_notion_page", - params={ - "page_id": page_id, - "content": content, - "connector_id": connector_id_from_context, - }, - context=context, - ) - - if result.rejected: - logger.info("Notion page update rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_page_id = result.params.get("page_id", page_id) - final_content = result.params.get("content", content) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - - logger.info( - f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.NOTION_CONNECTOR, - ) - ) - connector = result.scalars().first() - - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. 
Please select a valid account.", - } - actual_connector_id = connector.id - logger.info(f"Validated Notion connector: id={actual_connector_id}") - else: - logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } - - notion_connector = NotionHistoryConnector( - session=db_session, - connector_id=actual_connector_id, - ) - - result = await notion_connector.update_page( - page_id=final_page_id, - content=final_content, - ) - logger.info( - f"update_page result: {result.get('status')} - {result.get('message', '')}" - ) - - if result.get("status") == "success" and document_id is not None: - from app.services.notion import NotionKBSyncService - - logger.info(f"Updating knowledge base for document {document_id}...") - kb_service = NotionKBSyncService(db_session) - kb_result = await kb_service.sync_after_update( - document_id=document_id, - appended_content=final_content, - user_id=user_id, - search_space_id=search_space_id, - appended_block_ids=result.get("appended_block_ids"), + async with async_session_maker() as db_session: + metadata_service = NotionToolMetadataService(db_session) + context = await metadata_service.get_update_context( + search_space_id, user_id, page_title ) - if kb_result["status"] == "success": - result["message"] = ( - f"{result['message']}. Your knowledge base has also been updated." - ) - logger.info( - f"Knowledge base successfully updated for page {final_page_id}" - ) - elif kb_result["status"] == "not_indexed": - result["message"] = ( - f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." - ) - else: - result["message"] = ( - f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." - ) + if "error" in context: + error_msg = context["error"] + # Check if it's a "not found" error (softer handling for LLM) + if "not found" in error_msg.lower(): + logger.warning(f"Page not found: {error_msg}") + return { + "status": "not_found", + "message": error_msg, + } + else: + logger.error(f"Failed to fetch update context: {error_msg}") + return { + "status": "error", + "message": error_msg, + } + + account = context.get("account", {}) + if account.get("auth_expired"): logger.warning( - f"KB update failed for page {final_page_id}: {kb_result['message']}" + "Notion account %s has expired authentication", + account.get("id"), + ) + return { + "status": "auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + } + + page_id = context.get("page_id") + document_id = context.get("document_id") + connector_id_from_context = context.get("account", {}).get("id") + + logger.info( + f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})" + ) + result = request_approval( + action_type="notion_page_update", + tool_name="update_notion_page", + params={ + "page_id": page_id, + "content": content, + "connector_id": connector_id_from_context, + }, + context=context, + ) + + if result.rejected: + logger.info("Notion page update rejected by user") + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_page_id = result.params.get("page_id", page_id) + final_content = result.params.get("content", content) + final_connector_id = result.params.get( + "connector_id", connector_id_from_context + ) + + logger.info( + f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}" + ) + + from sqlalchemy.future import select + + from app.db import SearchSourceConnector, SearchSourceConnectorType + + if final_connector_id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + logger.error( + f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" + ) + return { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + } + actual_connector_id = connector.id + logger.info(f"Validated Notion connector: id={actual_connector_id}") + else: + logger.error("No connector found for this page") + return { + "status": "error", + "message": "No connector found for this page.", + } + + notion_connector = NotionHistoryConnector( + session=db_session, + connector_id=actual_connector_id, + ) + + result = await notion_connector.update_page( + page_id=final_page_id, + content=final_content, + ) + logger.info( + f"update_page result: {result.get('status')} - {result.get('message', '')}" + ) + + if result.get("status") == "success" and document_id is not None: + from app.services.notion import NotionKBSyncService + + logger.info( + f"Updating knowledge base for document {document_id}..." + ) + kb_service = NotionKBSyncService(db_session) + kb_result = await kb_service.sync_after_update( + document_id=document_id, + appended_content=final_content, + user_id=user_id, + search_space_id=search_space_id, + appended_block_ids=result.get("appended_block_ids"), ) - return result + if kb_result["status"] == "success": + result["message"] = ( + f"{result['message']}. Your knowledge base has also been updated." + ) + logger.info( + f"Knowledge base successfully updated for page {final_page_id}" + ) + elif kb_result["status"] == "not_indexed": + result["message"] = ( + f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." + ) + else: + result["message"] = ( + f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." 
+ ) + logger.warning( + f"KB update failed for page {final_page_id}: {kb_result['message']}" + ) + + return result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index 21272e01d..5f199a41b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -10,7 +10,7 @@ from sqlalchemy.future import select from app.agents.new_chat.tools.hitl import request_approval from app.connectors.onedrive.client import OneDriveClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker logger = logging.getLogger(__name__) @@ -48,6 +48,23 @@ def create_create_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the create_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured create_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def create_onedrive_file( name: str, @@ -70,173 +87,178 @@ def create_create_onedrive_file_tool( """ logger.info(f"create_onedrive_file called: name='{name}'") - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, - ) - ) - connectors = result.scalars().all() - - if not connectors: - return { - "status": "error", - "message": "No OneDrive connector found. 
Please connect OneDrive in your workspace settings.", - } - - accounts = [] - for c in connectors: - cfg = c.config or {} - accounts.append( - { - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - } - ) - - if all(a.get("auth_expired") for a in accounts): - return { - "status": "auth_error", - "message": "All connected OneDrive accounts need re-authentication.", - "connector_type": "onedrive", - } - - parent_folders: dict[int, list[dict[str, str]]] = {} - for acc in accounts: - cid = acc["id"] - if acc.get("auth_expired"): - parent_folders[cid] = [] - continue - try: - client = OneDriveClient(session=db_session, connector_id=cid) - items, err = await client.list_children("root") - if err: - logger.warning( - "Failed to list folders for connector %s: %s", cid, err - ) - parent_folders[cid] = [] - else: - parent_folders[cid] = [ - {"folder_id": item["id"], "name": item["name"]} - for item in items - if item.get("folder") is not None - and item.get("id") - and item.get("name") - ] - except Exception: - logger.warning( - "Error fetching folders for connector %s", cid, exc_info=True - ) - parent_folders[cid] = [] - - context: dict[str, Any] = { - "accounts": accounts, - "parent_folders": parent_folders, - } - - result = request_approval( - action_type="onedrive_file_creation", - tool_name="create_onedrive_file", - params={ - "name": name, - "content": content, - "connector_id": None, - "parent_folder_id": None, - }, - context=context, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_name = result.params.get("name", name) - final_content = result.params.get("content", content) - final_connector_id = result.params.get("connector_id") - final_parent_folder_id = result.params.get("parent_folder_id") - - if not final_name or not final_name.strip(): - return {"status": "error", "message": "File name cannot be empty."} - - final_name = _ensure_docx_extension(final_name) - - if final_connector_id is not None: + async with async_session_maker() as db_session: result = await db_session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) - connector = result.scalars().first() - else: - connector = connectors[0] + connectors = result.scalars().all() - if not connector: - return { - "status": "error", - "message": "Selected OneDrive connector is invalid.", + if not connectors: + return { + "status": "error", + "message": "No OneDrive connector found. 
Please connect OneDrive in your workspace settings.", + } + + accounts = [] + for c in connectors: + cfg = c.config or {} + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) + + if all(a.get("auth_expired") for a in accounts): + return { + "status": "auth_error", + "message": "All connected OneDrive accounts need re-authentication.", + "connector_type": "onedrive", + } + + parent_folders: dict[int, list[dict[str, str]]] = {} + for acc in accounts: + cid = acc["id"] + if acc.get("auth_expired"): + parent_folders[cid] = [] + continue + try: + client = OneDriveClient(session=db_session, connector_id=cid) + items, err = await client.list_children("root") + if err: + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) + parent_folders[cid] = [] + else: + parent_folders[cid] = [ + {"folder_id": item["id"], "name": item["name"]} + for item in items + if item.get("folder") is not None + and item.get("id") + and item.get("name") + ] + except Exception: + logger.warning( + "Error fetching folders for connector %s", + cid, + exc_info=True, + ) + parent_folders[cid] = [] + + context: dict[str, Any] = { + "accounts": accounts, + "parent_folders": parent_folders, } - docx_bytes = _markdown_to_docx(final_content or "") - - client = OneDriveClient(session=db_session, connector_id=connector.id) - created = await client.create_file( - name=final_name, - parent_id=final_parent_folder_id, - content=docx_bytes, - mime_type=DOCX_MIME, - ) - - logger.info( - f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" - ) - - kb_message_suffix = "" - try: - from app.services.onedrive import OneDriveKBSyncService - - kb_service = OneDriveKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - file_id=created.get("id"), - file_name=created.get("name", final_name), - mime_type=DOCX_MIME, - web_url=created.get("webUrl"), - content=final_content, - connector_id=connector.id, - search_space_id=search_space_id, - user_id=user_id, + result = request_approval( + action_type="onedrive_file_creation", + tool_name="create_onedrive_file", + params={ + "name": name, + "content": content, + "connector_id": None, + "parent_folder_id": None, + }, + context=context, ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." - return { - "status": "success", - "file_id": created.get("id"), - "name": created.get("name"), - "web_url": created.get("webUrl"), - "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_name = result.params.get("name", name) + final_content = result.params.get("content", content) + final_connector_id = result.params.get("connector_id") + final_parent_folder_id = result.params.get("parent_folder_id") + + if not final_name or not final_name.strip(): + return {"status": "error", "message": "File name cannot be empty."} + + final_name = _ensure_docx_extension(final_name) + + if final_connector_id is not None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + connector = result.scalars().first() + else: + connector = connectors[0] + + if not connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid.", + } + + docx_bytes = _markdown_to_docx(final_content or "") + + client = OneDriveClient(session=db_session, connector_id=connector.id) + created = await client.create_file( + name=final_name, + parent_id=final_parent_folder_id, + content=docx_bytes, + mime_type=DOCX_MIME, + ) + + logger.info( + f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" + ) + + kb_message_suffix = "" + try: + from app.services.onedrive import OneDriveKBSyncService + + kb_service = OneDriveKBSyncService(db_session) + kb_result = await kb_service.sync_after_create( + file_id=created.get("id"), + file_name=created.get("name", final_name), + mime_type=DOCX_MIME, + web_url=created.get("webUrl"), + content=final_content, + connector_id=connector.id, + search_space_id=search_space_id, + user_id=user_id, + ) + if kb_result["status"] == "success": + kb_message_suffix = ( + " Your knowledge base has also been updated." + ) + else: + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + except Exception as kb_err: + logger.warning(f"KB sync after create failed: {kb_err}") + kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." + + return { + "status": "success", + "file_id": created.get("id"), + "name": created.get("name"), + "web_url": created.get("webUrl"), + "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index a7f13b5df..4857ea988 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -13,6 +13,7 @@ from app.db import ( DocumentType, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) logger = logging.getLogger(__name__) @@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the delete_onedrive_file tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. 
Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured delete_onedrive_file tool + """ + del db_session # per-call session — see docstring + @tool async def delete_onedrive_file( file_name: str, @@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool( f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" ) - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return { "status": "error", "message": "OneDrive tool not properly configured.", } try: - doc_result = await db_session.execute( - select(Document) - .join( - SearchSourceConnector, - Document.connector_id == SearchSourceConnector.id, - ) - .filter( - and_( - Document.search_space_id == search_space_id, - Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower(Document.title) == func.lower(file_name), - SearchSourceConnector.user_id == user_id, - ) - ) - .order_by(Document.updated_at.desc().nullslast()) - .limit(1) - ) - document = doc_result.scalars().first() - - if not document: + async with async_session_maker() as db_session: doc_result = await db_session.execute( select(Document) .join( @@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool( and_( Document.search_space_id == search_space_id, Document.document_type == DocumentType.ONEDRIVE_FILE, - func.lower( - cast( - Document.document_metadata["onedrive_file_name"], - String, - ) - ) - == func.lower(file_name), + func.lower(Document.title) == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool( ) document = doc_result.scalars().first() - if not document: - return { - "status": "not_found", - "message": ( - f"File '{file_name}' not found in your indexed OneDrive files. " - "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " - "or (3) the file name is different." - ), - } - - if not document.connector_id: - return { - "status": "error", - "message": "Document has no associated connector.", - } - - meta = document.document_metadata or {} - file_id = meta.get("onedrive_file_id") - document_id = document.id - - if not file_id: - return { - "status": "error", - "message": "File ID is missing. 
Please re-index the file.", - } - - conn_result = await db_session.execute( - select(SearchSourceConnector).filter( - and_( - SearchSourceConnector.id == document.connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + if not document: + doc_result = await db_session.execute( + select(Document) + .join( + SearchSourceConnector, + Document.connector_id == SearchSourceConnector.id, + ) + .filter( + and_( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.ONEDRIVE_FILE, + func.lower( + cast( + Document.document_metadata[ + "onedrive_file_name" + ], + String, + ) + ) + == func.lower(file_name), + SearchSourceConnector.user_id == user_id, + ) + ) + .order_by(Document.updated_at.desc().nullslast()) + .limit(1) ) - ) - ) - connector = conn_result.scalars().first() - if not connector: - return { - "status": "error", - "message": "OneDrive connector not found or access denied.", - } + document = doc_result.scalars().first() - cfg = connector.config or {} - if cfg.get("auth_expired"): - return { - "status": "auth_error", - "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "onedrive", - } + if not document: + return { + "status": "not_found", + "message": ( + f"File '{file_name}' not found in your indexed OneDrive files. " + "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " + "or (3) the file name is different." + ), + } - context = { - "file": { - "file_id": file_id, - "name": file_name, - "document_id": document_id, - "web_url": meta.get("web_url"), - }, - "account": { - "id": connector.id, - "name": connector.name, - "user_email": cfg.get("user_email"), - }, - } + if not document.connector_id: + return { + "status": "error", + "message": "Document has no associated connector.", + } - result = request_approval( - action_type="onedrive_file_trash", - tool_name="delete_onedrive_file", - params={ - "file_id": file_id, - "connector_id": connector.id, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) + meta = document.document_metadata or {} + file_id = meta.get("onedrive_file_id") + document_id = document.id - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + if not file_id: + return { + "status": "error", + "message": "File ID is missing. 
Please re-index the file.", + } - final_file_id = result.params.get("file_id", file_id) - final_connector_id = result.params.get("connector_id", connector.id) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - if final_connector_id != connector.id: - result = await db_session.execute( + conn_result = await db_session.execute( select(SearchSourceConnector).filter( and_( - SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.connector_type @@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool( ) ) ) - validated_connector = result.scalars().first() - if not validated_connector: + connector = conn_result.scalars().first() + if not connector: return { "status": "error", - "message": "Selected OneDrive connector is invalid or has been disconnected.", + "message": "OneDrive connector not found or access denied.", } - actual_connector_id = validated_connector.id - else: - actual_connector_id = connector.id - logger.info( - f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" - ) + cfg = connector.config or {} + if cfg.get("auth_expired"): + return { + "status": "auth_error", + "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", + "connector_type": "onedrive", + } - client = OneDriveClient( - session=db_session, connector_id=actual_connector_id - ) - await client.trash_file(final_file_id) + context = { + "file": { + "file_id": file_id, + "name": file_name, + "document_id": document_id, + "web_url": meta.get("web_url"), + }, + "account": { + "id": connector.id, + "name": connector.name, + "user_email": cfg.get("user_email"), + }, + } - logger.info( - f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" - ) - - trash_result: dict[str, Any] = { - "status": "success", - "file_id": final_file_id, - "message": f"Successfully moved '{file_name}' to the recycle bin.", - } - - deleted_from_kb = False - if final_delete_from_kb and document_id: - try: - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - doc = doc_result.scalars().first() - if doc: - await db_session.delete(doc) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - trash_result["warning"] = ( - f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" - ) - - trash_result["deleted_from_kb"] = deleted_from_kb - if deleted_from_kb: - trash_result["message"] = ( - f"{trash_result.get('message', '')} (also removed from knowledge base)" + result = request_approval( + action_type="onedrive_file_trash", + tool_name="delete_onedrive_file", + params={ + "file_id": file_id, + "connector_id": connector.id, + "delete_from_kb": delete_from_kb, + }, + context=context, ) - return trash_result + if result.rejected: + return { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + } + + final_file_id = result.params.get("file_id", file_id) + final_connector_id = result.params.get("connector_id", connector.id) + final_delete_from_kb = result.params.get( + "delete_from_kb", delete_from_kb + ) + + if final_connector_id != connector.id: + result = await db_session.execute( + select(SearchSourceConnector).filter( + and_( + SearchSourceConnector.id == final_connector_id, + SearchSourceConnector.search_space_id + == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + ) + validated_connector = result.scalars().first() + if not validated_connector: + return { + "status": "error", + "message": "Selected OneDrive connector is invalid or has been disconnected.", + } + actual_connector_id = validated_connector.id + else: + actual_connector_id = connector.id + + logger.info( + f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" + ) + + client = OneDriveClient( + session=db_session, connector_id=actual_connector_id + ) + await client.trash_file(final_file_id) + + logger.info( + f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" + ) + + trash_result: dict[str, Any] = { + "status": "success", + "file_id": final_file_id, + "message": f"Successfully moved '{file_name}' to the recycle bin.", + } + + deleted_from_kb = False + if final_delete_from_kb and document_id: + try: + doc_result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + doc = doc_result.scalars().first() + if doc: + await db_session.delete(doc) + await db_session.commit() + deleted_from_kb = True + logger.info( + f"Deleted document {document_id} from knowledge base" + ) + else: + logger.warning(f"Document {document_id} not found in KB") + except Exception as e: + logger.error(f"Failed to delete document from KB: {e}") + await db_session.rollback() + trash_result["warning"] = ( + f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}" + ) + + trash_result["deleted_from_kb"] = deleted_from_kb + if deleted_from_kb: + trash_result["message"] = ( + f"{trash_result.get('message', '')} (also removed from knowledge base)" + ) + + return trash_result except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index e8bab36fd..b842d7a20 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -824,13 +824,22 @@ async def build_tools_async( """Async version of build_tools that also loads MCP tools from database. Design Note: - This function exists because MCP tools require database queries to load user configs, - while built-in tools are created synchronously from static code. + This function exists because MCP tools require database queries to load + user configs, while built-in tools are created synchronously from static + code. - Alternative: We could make build_tools() itself async and always query the database, - but that would force async everywhere even when only using built-in tools. The current - design keeps the simple case (static tools only) synchronous while supporting dynamic - database-loaded tools through this async wrapper. 
+ Alternative: We could make build_tools() itself async and always query + the database, but that would force async everywhere even when only using + built-in tools. The current design keeps the simple case (static tools + only) synchronous while supporting dynamic database-loaded tools through + this async wrapper. + + Phase 1.3: built-in tool construction (CPU; runs in a thread pool to + avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on + the event loop) are kicked off concurrently. Cold-path savings are + bounded by the slower of the two — typically MCP at ~200ms-1.7s — + so the parallelization recovers the ~50-200ms previously spent + serially on built-in construction. Args: dependencies: Dict containing all possible dependencies @@ -843,33 +852,70 @@ async def build_tools_async( List of configured tool instances ready for the agent, including MCP tools. """ + import asyncio import time _perf_log = logging.getLogger("surfsense.perf") _perf_log.setLevel(logging.DEBUG) + can_load_mcp = ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ) + + # Built-in tool construction is synchronous + CPU-only. Off-loop it so + # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure + # function over its inputs — safe to thread-shift. _t0 = time.perf_counter() - tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + builtin_task = asyncio.create_task( + asyncio.to_thread( + build_tools, dependencies, enabled_tools, disabled_tools, additional_tools + ) + ) + + mcp_task: asyncio.Task | None = None + if can_load_mcp: + mcp_task = asyncio.create_task( + load_mcp_tools( + dependencies["db_session"], + dependencies["search_space_id"], + ) + ) + + # Surface failures from each task independently so a flaky MCP + # endpoint never poisons built-in tool registration. ``return_exceptions`` + # gives us per-task exceptions instead of dropping the second result + # when the first raises. + if mcp_task is not None: + builtin_result, mcp_result = await asyncio.gather( + builtin_task, mcp_task, return_exceptions=True + ) + else: + builtin_result = await builtin_task + mcp_result = None + + if isinstance(builtin_result, BaseException): + raise builtin_result # built-in registration failure is non-recoverable + tools: list[BaseTool] = builtin_result _perf_log.info( - "[build_tools_async] Built-in tools in %.3fs (%d tools)", + "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(tools), ) - # Load MCP tools if requested and dependencies are available - if ( - include_mcp_tools - and "db_session" in dependencies - and "search_space_id" in dependencies - ): - try: - _t0 = time.perf_counter() - mcp_tools = await load_mcp_tools( - dependencies["db_session"], - dependencies["search_space_id"], + if mcp_task is not None: + if isinstance(mcp_result, BaseException): + # ``return_exceptions=True`` captures the exception out-of-band, + # so ``sys.exc_info()`` is empty here. Pass the captured + # exception via ``exc_info=`` to get a real traceback. 
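Taken on its own, the scheduling described in the docstring and comments above is a small amount of asyncio plumbing; the following is a minimal, self-contained sketch (hypothetical build_builtin_tools / load_remote_tools stand-ins, not the real registry functions) of running the CPU-bound build in a worker thread while the I/O-bound load stays on the event loop, then surfacing each task's failure independently:

    import asyncio
    import logging

    logger = logging.getLogger(__name__)


    def build_builtin_tools() -> list[str]:
        # CPU-only stand-in; safe to run in a worker thread.
        return ["tool_a", "tool_b"]


    async def load_remote_tools() -> list[str]:
        # I/O-bound stand-in; stays on the event loop.
        await asyncio.sleep(0.05)
        return ["mcp_tool"]


    async def build_all_tools() -> list[str]:
        builtin_task = asyncio.create_task(asyncio.to_thread(build_builtin_tools))
        remote_task = asyncio.create_task(load_remote_tools())

        # return_exceptions=True keeps one task's failure from discarding the
        # other's result; each exception is handed back as a value instead.
        builtin_result, remote_result = await asyncio.gather(
            builtin_task, remote_task, return_exceptions=True
        )

        if isinstance(builtin_result, BaseException):
            raise builtin_result  # built-in tools are required

        tools = list(builtin_result)
        if isinstance(remote_result, BaseException):
            # The exception was captured out-of-band, so pass it explicitly
            # to get a traceback in the log record.
            logger.error(
                "remote tool load failed: %s", remote_result, exc_info=remote_result
            )
        else:
            tools.extend(remote_result)
        return tools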
+ logging.error( + "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result ) + else: + mcp_tools = mcp_result or [] _perf_log.info( - "[build_tools_async] MCP tools loaded in %.3fs (%d tools)", + "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)", time.perf_counter() - _t0, len(mcp_tools), ) @@ -879,8 +925,6 @@ async def build_tools_async( len(mcp_tools), [t.name for t in mcp_tools], ) - except Exception as e: - logging.exception("Failed to load MCP tools: %s", e) logging.info( "Total tools for agent: %d — %s", diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py index b8b1527c7..2965f2f02 100644 --- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py +++ b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py @@ -15,7 +15,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument +from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker from app.utils.document_converters import embed_text @@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): """ Factory function to create the search_surfsense_docs tool. + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + Args: - db_session: Database session for executing queries + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. Returns: A configured tool function for searching Surfsense documentation """ + del db_session # per-call session — see docstring @tool async def search_surfsense_docs(query: str, top_k: int = 10) -> str: @@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession): Returns: Relevant documentation content formatted with chunk IDs for citations """ - return await search_surfsense_docs_async( - query=query, - db_session=db_session, - top_k=top_k, - ) + async with async_session_maker() as db_session: + return await search_surfsense_docs_async( + query=query, + db_session=db_session, + top_k=top_k, + ) return search_surfsense_docs diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py index d7b000853..0fc52b5c7 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_list_teams_channels_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the list_teams_channels tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. 
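This per-call-session factory pattern recurs across the connector tools in this patch; below is a compressed, self-contained sketch (hypothetical create_example_tool and a throwaway SELECT 1 probe, not one of the real factories; async_session_maker is the project's own) of what it looks like end to end:

    from typing import Any

    from langchain_core.tools import tool
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.db import async_session_maker


    def create_example_tool(db_session: AsyncSession | None = None):
        # The injected session is kept only for registry-signature compatibility;
        # dropping it here guarantees the closure never holds request state.
        del db_session

        @tool
        async def example_tool(query: str) -> dict[str, Any]:
            """Run a read-only lookup with a session scoped to this call."""
            # A fresh session per invocation keeps the cached closure reusable
            # across requests: nothing stale or closed is captured.
            async with async_session_maker() as session:
                row = await session.execute(text("SELECT 1"))
                return {"status": "success", "query": query, "probe": row.scalar()}

        return example_tool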
+ + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured list_teams_channels tool + """ + del db_session # per-call session — see docstring + @tool async def list_teams_channels() -> dict[str, Any]: """List all Microsoft Teams and their channels the user has access to. @@ -23,63 +42,66 @@ def create_list_teams_channels_tool( Dictionary with status and a list of teams, each containing team_id, team_name, and a list of channels (id, name). """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get( - f"{GRAPH_API}/me/joinedTeams", headers=headers + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if teams_resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if teams_resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {teams_resp.status_code}", - } + token = await get_access_token(db_session, connector) + headers = {"Authorization": f"Bearer {token}"} - teams_data = teams_resp.json().get("value", []) - result_teams = [] - - async with httpx.AsyncClient(timeout=20.0) as client: - for team in teams_data: - team_id = team["id"] - ch_resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels", - headers=headers, - ) - channels = [] - if ch_resp.status_code == 200: - channels = [ - {"id": ch["id"], "name": ch.get("displayName", "")} - for ch in ch_resp.json().get("value", []) - ] - result_teams.append( - { - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - } + async with httpx.AsyncClient(timeout=20.0) as client: + teams_resp = await client.get( + f"{GRAPH_API}/me/joinedTeams", headers=headers ) - return { - "status": "success", - "teams": result_teams, - "total_teams": len(result_teams), - } + if teams_resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. 
Please re-authenticate.", + "connector_type": "teams", + } + if teams_resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {teams_resp.status_code}", + } + + teams_data = teams_resp.json().get("value", []) + result_teams = [] + + async with httpx.AsyncClient(timeout=20.0) as client: + for team in teams_data: + team_id = team["id"] + ch_resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels", + headers=headers, + ) + channels = [] + if ch_resp.status_code == 200: + channels = [ + {"id": ch["id"], "name": ch.get("displayName", "")} + for ch in ch_resp.json().get("value", []) + ] + result_teams.append( + { + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + } + ) + + return { + "status": "success", + "teams": result_teams, + "total_teams": len(result_teams), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py index d24a7e4d3..0ebda021e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -5,6 +5,8 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker + from ._auth import GRAPH_API, get_access_token, get_teams_connector logger = logging.getLogger(__name__) @@ -15,6 +17,23 @@ def create_read_teams_messages_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the read_teams_messages tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + + Returns: + Configured read_teams_messages tool + """ + del db_session # per-call session — see docstring + @tool async def read_teams_messages( team_id: str, @@ -32,65 +51,68 @@ def create_read_teams_messages_tool( Dictionary with status and a list of messages including id, sender, content, timestamp. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} limit = min(limit, 50) try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.get( - f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", - headers={"Authorization": f"Bearer {token}"}, - params={"$top": limit}, + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. 
Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "error", - "message": "Insufficient permissions to read this channel.", - } - if resp.status_code != 200: - return { - "status": "error", - "message": f"Graph API error: {resp.status_code}", - } + token = await get_access_token(db_session, connector) - raw_msgs = resp.json().get("value", []) - messages = [] - for m in raw_msgs: - sender = m.get("from", {}) - user_info = sender.get("user", {}) if sender else {} - body = m.get("body", {}) - messages.append( - { - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.get( + f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages", + headers={"Authorization": f"Bearer {token}"}, + params={"$top": limit}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "error", + "message": "Insufficient permissions to read this channel.", + } + if resp.status_code != 200: + return { + "status": "error", + "message": f"Graph API error: {resp.status_code}", } - ) - return { - "status": "success", - "team_id": team_id, - "channel_id": channel_id, - "messages": messages, - "total": len(messages), - } + raw_msgs = resp.json().get("value", []) + messages = [] + for m in raw_msgs: + sender = m.get("from", {}) + user_info = sender.get("user", {}) if sender else {} + body = m.get("body", {}) + messages.append( + { + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + } + ) + + return { + "status": "success", + "team_id": team_id, + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py index fd8d00870..6f40d27e1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -6,6 +6,7 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval +from app.db import async_session_maker from ._auth import GRAPH_API, get_access_token, get_teams_connector @@ -17,6 +18,23 @@ def create_send_teams_message_tool( search_space_id: int | None = None, user_id: str | None = None, ): + """ + Factory function to create the send_teams_message tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + + Args: + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. 
+ + Returns: + Configured send_teams_message tool + """ + del db_session # per-call session — see docstring + @tool async def send_teams_message( team_id: str, @@ -39,70 +57,73 @@ def create_send_teams_message_tool( IMPORTANT: - If status is "rejected", the user explicitly declined. Do NOT retry. """ - if db_session is None or search_space_id is None or user_id is None: + if search_space_id is None or user_id is None: return {"status": "error", "message": "Teams tool not properly configured."} try: - connector = await get_teams_connector(db_session, search_space_id, user_id) - if not connector: - return {"status": "error", "message": "No Teams connector found."} + async with async_session_maker() as db_session: + connector = await get_teams_connector( + db_session, search_space_id, user_id + ) + if not connector: + return {"status": "error", "message": "No Teams connector found."} - result = request_approval( - action_type="teams_send_message", - tool_name="send_teams_message", - params={ - "team_id": team_id, - "channel_id": channel_id, - "content": content, - }, - context={"connector_id": connector.id}, - ) - - if result.rejected: - return { - "status": "rejected", - "message": "User declined. Message was not sent.", - } - - final_content = result.params.get("content", content) - final_team = result.params.get("team_id", team_id) - final_channel = result.params.get("channel_id", channel_id) - - token = await get_access_token(db_session, connector) - - async with httpx.AsyncClient(timeout=20.0) as client: - resp = await client.post( - f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", - headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", + result = request_approval( + action_type="teams_send_message", + tool_name="send_teams_message", + params={ + "team_id": team_id, + "channel_id": channel_id, + "content": content, }, - json={"body": {"content": final_content}}, + context={"connector_id": connector.id}, ) - if resp.status_code == 401: - return { - "status": "auth_error", - "message": "Teams token expired. Please re-authenticate.", - "connector_type": "teams", - } - if resp.status_code == 403: - return { - "status": "insufficient_permissions", - "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", - } - if resp.status_code not in (200, 201): - return { - "status": "error", - "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", - } + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } - msg_data = resp.json() - return { - "status": "success", - "message_id": msg_data.get("id"), - "message": "Message sent to Teams channel.", - } + final_content = result.params.get("content", content) + final_team = result.params.get("team_id", team_id) + final_channel = result.params.get("channel_id", channel_id) + + token = await get_access_token(db_session, connector) + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"body": {"content": final_content}}, + ) + + if resp.status_code == 401: + return { + "status": "auth_error", + "message": "Teams token expired. 
Please re-authenticate.", + "connector_type": "teams", + } + if resp.status_code == 403: + return { + "status": "insufficient_permissions", + "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", + } + if resp.status_code not in (200, 201): + return { + "status": "error", + "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", + } + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": "Message sent to Teams channel.", + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/update_memory.py b/surfsense_backend/app/agents/new_chat/tools/update_memory.py index 4128ac0dc..fbc9edbba 100644 --- a/surfsense_backend/app/agents/new_chat/tools/update_memory.py +++ b/surfsense_backend/app/agents/new_chat/tools/update_memory.py @@ -26,7 +26,7 @@ from langchain_core.tools import tool from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import SearchSpace, User +from app.db import SearchSpace, User, async_session_maker logger = logging.getLogger(__name__) @@ -295,6 +295,25 @@ def create_update_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the user-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + user_id: ID of the user whose memory document is being updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the user-memory scope. + """ + del db_session # per-call session — see docstring uid = UUID(user_id) if isinstance(user_id, str) else user_id @tool @@ -311,26 +330,26 @@ def create_update_memory_tool( updated_memory: The FULL updated markdown document (not a diff). 
""" try: - result = await db_session.execute(select(User).where(User.id == uid)) - user = result.scalars().first() - if not user: - return {"status": "error", "message": "User not found."} + async with async_session_maker() as db_session: + result = await db_session.execute(select(User).where(User.id == uid)) + user = result.scalars().first() + if not user: + return {"status": "error", "message": "User not found."} - old_memory = user.memory_md + old_memory = user.memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(user, "memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="memory", - scope="user", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr(user, "memory_md", content), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="memory", + scope="user", + ) except Exception as e: logger.exception("Failed to update user memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update memory: {e}", @@ -344,6 +363,27 @@ def create_update_team_memory_tool( db_session: AsyncSession, llm: Any | None = None, ): + """Factory function to create the team-memory update tool. + + The tool acquires its own short-lived ``AsyncSession`` per call via + :data:`async_session_maker` so the closure is safe to share across + HTTP requests by the compiled-agent cache. Capturing a per-request + session here would surface stale/closed sessions on cache hits. + The session's bound ``commit``/``rollback`` methods are captured at + call time, after ``async with`` has bound ``db_session`` locally. + + Args: + search_space_id: ID of the search space whose team memory is being + updated. + db_session: Reserved for registry compatibility. Per-call sessions + are opened via :data:`async_session_maker` inside the tool body. + llm: Optional LLM for the forced-rewrite path. + + Returns: + Configured update_memory tool for the team-memory scope. + """ + del db_session # per-call session — see docstring + @tool async def update_memory(updated_memory: str) -> dict[str, Any]: """Update the team's shared memory document for this search space. @@ -359,28 +399,30 @@ def create_update_team_memory_tool( updated_memory: The FULL updated markdown document (not a diff). 
""" try: - result = await db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - space = result.scalars().first() - if not space: - return {"status": "error", "message": "Search space not found."} + async with async_session_maker() as db_session: + result = await db_session.execute( + select(SearchSpace).where(SearchSpace.id == search_space_id) + ) + space = result.scalars().first() + if not space: + return {"status": "error", "message": "Search space not found."} - old_memory = space.shared_memory_md + old_memory = space.shared_memory_md - return await _save_memory( - updated_memory=updated_memory, - old_memory=old_memory, - llm=llm, - apply_fn=lambda content: setattr(space, "shared_memory_md", content), - commit_fn=db_session.commit, - rollback_fn=db_session.rollback, - label="team memory", - scope="team", - ) + return await _save_memory( + updated_memory=updated_memory, + old_memory=old_memory, + llm=llm, + apply_fn=lambda content: setattr( + space, "shared_memory_md", content + ), + commit_fn=db_session.commit, + rollback_fn=db_session.rollback, + label="team memory", + scope="team", + ) except Exception as e: logger.exception("Failed to update team memory: %s", e) - await db_session.rollback() return { "status": "error", "message": f"Failed to update team memory: {e}", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 14d7f4d23..2c9b4f390 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -421,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None: OpenRouterIntegrationService.get_instance().stop_background_refresh() +async def _warm_agent_jit_caches() -> None: + """Pay the LangChain / LangGraph / Deepagents JIT cost at startup. + + Why + ---- + A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema + generation chain takes 1.5-2 seconds of pure CPU on first invocation + inside any Python process: the graph compiler builds reducers, + Pydantic v2 generates and JITs validator schemas, deepagents + eagerly compiles its general-purpose subagent, etc. Subsequent + compiles in the same process pay only ~50% of that cost (the lazy + JIT bits are cached in module-level dicts). + + Doing one throwaway compile during ``lifespan`` startup pre-pays + that cost so the *first real request* doesn't. We do NOT prime + :mod:`agent_cache` because the cache key requires real + ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the + throwaway agent is genuinely thrown away and immediately collected. + + Safety + ------ + * No DB access. We construct a stub LLM (no real keys), pass an + empty tools list, and pass ``checkpointer=None`` so we never + touch Postgres. + * Bounded by ``asyncio.wait_for`` so a hang here can never block + worker startup. On any failure, we log + swallow — the worst + case is the first real request pays the full cold cost (i.e. + pre-warmup behaviour). + """ + import time as _time + + logger = logging.getLogger(__name__) + t0 = _time.perf_counter() + try: + from langchain.agents import create_agent + from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + TodoListMiddleware, + ToolCallLimitMiddleware, + ) + from langchain_core.language_models.fake_chat_models import ( + FakeListChatModel, + ) + from langchain_core.tools import tool + + from app.agents.new_chat.context import SurfSenseContextSchema + + # Minimal LLM stub. 
``FakeListChatModel`` satisfies + # ``BaseChatModel`` without any network or auth — perfect for + # exercising the compile path without side effects. + stub_llm = FakeListChatModel(responses=["warmup-response"]) + + # Two trivial tools with arg + return schemas — exercises the + # Pydantic v2 schema JIT path. Without at least one tool the + # graph compile skips the tool-loop bytecode generation that + # accounts for ~30-50% of cold compile cost. + @tool + def _warmup_tool_a(query: str, limit: int = 5) -> str: + """Warmup tool A — never actually invoked.""" + return query[:limit] + + @tool + def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]: + """Warmup tool B — never actually invoked.""" + return {"name": name, "value": value} + + # A handful of common middleware so the compile pre-pays the + # ``AgentMiddleware`` resolver path. These instances never run + # because the throwaway agent is immediately collected. + # ``SubAgentMiddleware`` is the single heaviest line in cold + # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to + # compile its general-purpose subagent's full inner graph), + # so we include it here to make sure that compile path is JIT'd. + warmup_middleware: list = [ + TodoListMiddleware(), + ModelCallLimitMiddleware( + thread_limit=120, run_limit=80, exit_behavior="end" + ), + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ), + ] + try: + from deepagents import SubAgentMiddleware + from deepagents.backends import StateBackend + from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT + + gp_warmup_spec = { # type: ignore[var-annotated] + **GENERAL_PURPOSE_SUBAGENT, + "model": stub_llm, + "tools": [_warmup_tool_a], + "middleware": [TodoListMiddleware()], + } + warmup_middleware.append( + SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec]) + ) + except Exception: + # Deepagents missing/incompatible — middleware-only warmup + # still produces a useful (smaller) speedup. + logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True) + + compiled = create_agent( + stub_llm, + tools=[_warmup_tool_a, _warmup_tool_b], + system_prompt="You are a warmup stub.", + middleware=warmup_middleware, + context_schema=SurfSenseContextSchema, + checkpointer=None, + ) + + # Touch the compiled graph's stream_channels / nodes so any + # remaining lazy schema work fires now instead of on first + # real invocation. + _ = list(getattr(compiled, "nodes", {}).keys()) + + del compiled + logger.info( + "[startup] Agent JIT warmup completed in %.3fs", + _time.perf_counter() - t0, + ) + except Exception: + logger.warning( + "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first " + "real request will pay the full compile cost)", + _time.perf_counter() - t0, + exc_info=True, + ) + + @asynccontextmanager async def lifespan(app: FastAPI): # Tune GC: lower gen-2 threshold so long-lived garbage is collected @@ -445,6 +574,18 @@ async def lifespan(app: FastAPI): "Docs will be indexed on the next restart." ) + # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays + # worker readiness. ``shield`` so Uvicorn cancelling startup + # doesn't leave half-warmed Pydantic schemas in an inconsistent + # state. 
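Read in isolation, the bounding idiom described in the comment above reduces to a few lines; a minimal sketch, assuming a hypothetical warm_caches coroutine and an illustrative 20-second timeout:

    import asyncio
    import logging

    logger = logging.getLogger(__name__)


    async def warm_caches() -> None:
        # Stand-in for the JIT warmup: CPU / lazy-import work, no DB access.
        await asyncio.sleep(0.1)


    async def bounded_warmup(timeout: float = 20.0) -> None:
        try:
            # wait_for bounds how long startup blocks; shield lets the inner
            # warmup run to completion even if the outer wait times out, so
            # module-level caches are never abandoned mid-write.
            await asyncio.wait_for(asyncio.shield(warm_caches()), timeout=timeout)
        except Exception:
            # Non-fatal by design: the first real request simply pays the
            # full cold compile cost, exactly as before the warmup existed.
            logger.warning(
                "warmup timed out or failed; continuing startup", exc_info=True
            )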
+ try: + await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20) + except (TimeoutError, Exception): # pragma: no cover - defensive + logging.getLogger(__name__).warning( + "[startup] Agent JIT warmup hit timeout/error — skipping; " + "first real request will pay the full compile cost." + ) + log_system_snapshot("startup_complete") yield diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 7c55da2e5..45bcfd00f 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -1,6 +1,8 @@ import asyncio +import os import time from datetime import datetime +from threading import Lock from typing import Any import httpx @@ -2769,12 +2771,22 @@ class ConnectorService: """ Get all available (enabled) connector types for a search space. + Phase 1.4: results are cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session + identity — the cached value is plain data, safe to share across + requests. Invalidate on connector add/update/delete via + :func:`invalidate_connector_discovery_cache`. + Args: search_space_id: The search space ID Returns: List of SearchSourceConnectorType enums for enabled connectors """ + cached = _get_cached_connectors(search_space_id) + if cached is not None: + return list(cached) + query = ( select(SearchSourceConnector.connector_type) .filter( @@ -2784,8 +2796,9 @@ class ConnectorService: ) result = await self.session.execute(query) - connector_types = result.scalars().all() - return list(connector_types) + connector_types = list(result.scalars().all()) + _set_cached_connectors(search_space_id, connector_types) + return connector_types async def get_available_document_types( self, @@ -2794,12 +2807,22 @@ class ConnectorService: """ Get all document types that have at least one document in the search space. + Phase 1.4: cached per ``search_space_id`` for + :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via + :func:`invalidate_connector_discovery_cache` when a connector + finishes indexing new documents (or document types are otherwise + added/removed). + Args: search_space_id: The search space ID Returns: List of document type strings that have documents indexed """ + cached = _get_cached_doc_types(search_space_id) + if cached is not None: + return list(cached) + from sqlalchemy import distinct from app.db import Document @@ -2809,5 +2832,164 @@ class ConnectorService: ) result = await self.session.execute(query) - doc_types = result.scalars().all() - return [str(dt) for dt in doc_types] + doc_types = [str(dt) for dt in result.scalars().all()] + _set_cached_doc_types(search_space_id, doc_types) + return doc_types + + +# --------------------------------------------------------------------------- +# Connector / document-type discovery TTL cache (Phase 1.4) +# --------------------------------------------------------------------------- +# +# Both ``get_available_connectors`` and ``get_available_document_types`` are +# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query +# hits Postgres and contributes to per-turn agent build latency. Their +# results change infrequently — only when the user adds/edits/removes a +# connector, or when an indexer commits a new document type. A short TTL +# cache (default 30s, env-tunable) collapses N concurrent calls into one +# DB roundtrip with bounded staleness. 
+# +# Invalidation: connector mutation routes (create / update / delete) call +# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the +# entry for the affected space. Multi-replica deployments still pay one +# DB roundtrip per replica per TTL window, which is fine — staleness is +# bounded and the alternative (cross-replica fanout) is not worth the +# coupling here. + +_DISCOVERY_TTL_SECONDS: float = float( + os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30") +) + +# Per-search-space caches. Keyed by ``search_space_id``; value is +# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock — +# read-mostly workload, sub-microsecond contention. +_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {} +_doc_types_cache: dict[int, tuple[float, list[str]]] = {} +_cache_lock = Lock() + + +def _get_cached_connectors( + search_space_id: int, +) -> list[SearchSourceConnectorType] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _connectors_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _connectors_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_connectors( + search_space_id: int, payload: list[SearchSourceConnectorType] +) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _connectors_cache[search_space_id] = (expires_at, list(payload)) + + +def _get_cached_doc_types(search_space_id: int) -> list[str] | None: + if _DISCOVERY_TTL_SECONDS <= 0: + return None + with _cache_lock: + entry = _doc_types_cache.get(search_space_id) + if entry is None: + return None + expires_at, payload = entry + if time.monotonic() >= expires_at: + _doc_types_cache.pop(search_space_id, None) + return None + return payload + + +def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None: + if _DISCOVERY_TTL_SECONDS <= 0: + return + expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS + with _cache_lock: + _doc_types_cache[search_space_id] = (expires_at, list(payload)) + + +def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None: + """Drop cached discovery results for ``search_space_id`` (or all spaces). + + Connector CRUD routes / indexer pipelines call this when they mutate + the rows backing :func:`ConnectorService.get_available_connectors` / + :func:`get_available_document_types`. ``None`` clears every space — + useful in tests and on bulk imports. + """ + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + _doc_types_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + _doc_types_cache.pop(search_space_id, None) + + +def _invalidate_connectors_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _connectors_cache.clear() + else: + _connectors_cache.pop(search_space_id, None) + + +def _invalidate_doc_types_only(search_space_id: int | None = None) -> None: + with _cache_lock: + if search_space_id is None: + _doc_types_cache.clear() + else: + _doc_types_cache.pop(search_space_id, None) + + +def _register_invalidation_listeners() -> None: + """Wire SQLAlchemy ORM events so cache stays consistent automatically. 
+ + Listening on ``after_insert`` / ``after_update`` / ``after_delete`` + means every successful INSERT/UPDATE/DELETE that goes through the ORM + invalidates the affected search space's cached discovery payload — + no need to sprinkle ``invalidate_*`` calls across 30+ connector + routes. Bulk operations that bypass the ORM (e.g. + ``session.execute(insert(...))`` without a mapped object) still need + explicit invalidation; document indexers already commit through the + ORM so document-type discovery is covered. + """ + from sqlalchemy import event + + # Imported here (not at module top) to avoid a circular import: + # app.services.connector_service is itself imported from app.db's + # ecosystem indirectly via several CRUD modules. + from app.db import Document, SearchSourceConnector + + def _connector_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_connectors_only(int(sid)) + + def _document_changed(_mapper, _connection, target) -> None: + sid = getattr(target, "search_space_id", None) + if sid is not None: + _invalidate_doc_types_only(int(sid)) + + for evt in ("after_insert", "after_update", "after_delete"): + event.listen(SearchSourceConnector, evt, _connector_changed) + event.listen(Document, evt, _document_changed) + + +try: + _register_invalidation_listeners() +except Exception: # pragma: no cover - defensive; never block module import + import logging as _logging + + _logging.getLogger(__name__).exception( + "Failed to register connector discovery cache invalidation listeners; " + "stale cache risk: explicit invalidate_connector_discovery_cache calls " + "may be required." + ) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 268a4401e..f7ddd8909 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.errors import BusyError from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection @@ -559,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None: ) +async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None: + """Wait for a discarded speculative agent build to release shared state. + + Used by the parallel preflight + agent-build path. The speculative build + closes over the request-scoped ``AsyncSession`` (for the brief connector + discovery / tool-factory window before its CPU work moves into a worker + thread). If preflight reports a 429 we want to fall back to the original + repin → reload → rebuild path, but we MUST NOT touch ``session`` again + until any in-flight session work owned by the speculative build has + fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not + concurrency-safe and the same hazard cost us a hard ``InvalidRequestError`` + earlier in this PR (see ``connector_service`` parallel-gather revert). + + We simply ``await`` the task and swallow any exception: in this path the + build's outcome is irrelevant — success populates the agent cache (a free + side effect), failure is discarded. 
The wasted CPU is acceptable since + 429 fallbacks are rare and the original sequential code also paid the + full build cost on the same path. + """ + with contextlib.suppress(BaseException): + await task + + def _classify_stream_exception( exc: Exception, *, @@ -696,6 +720,7 @@ async def _stream_agent_events( fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, fallback_commit_thread_id: int | None = None, + runtime_context: Any = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -801,7 +826,18 @@ async def _stream_agent_events( return event return None - async for event in agent.astream_events(input_data, config=config, version="v2"): + # Per-invocation runtime context (Phase 1.5). When supplied, + # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids`` + # from ``runtime.context`` instead of its constructor closure — the + # prerequisite that lets the compiled-agent cache (Phase 1) reuse a + # single graph across turns. Astream_events_kwargs stays empty when + # callers leave ``runtime_context`` as ``None`` to preserve the + # legacy code path bit-for-bit. + astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} + if runtime_context is not None: + astream_kwargs["context"] = runtime_context + + async for event in agent.astream_events(input_data, **astream_kwargs): event_type = event.get("event", "") if event_type == "on_chat_model_stream": @@ -2560,23 +2596,102 @@ async def stream_new_chat( # Detecting a 429 here lets us repin BEFORE the planner/classifier/ # title-generation LLM calls fan out and each independently hit the # same upstream rate limit. - if ( + # + # PERF: preflight is a network round-trip to the LLM provider (~1-5s) + # and is independent of the agent build (CPU-bound, ~5-7s). They used + # to run sequentially → ``preflight + build`` on cold cache = 11.5s. + # We now kick off preflight as a background task FIRST, then run the + # synchronous setup work and the agent build in parallel. In the + # success path (the common case) total wall time drops to roughly + # ``max(preflight, build)`` — the preflight finishes during the + # agent compile and we just consume its result. In the rare 429 + # path the speculative build is awaited to completion (so its + # session usage is fully released) via + # :func:`_settle_speculative_agent_build`, then discarded, and + # we fall back to the original repin-and-rebuild flow. 
+ preflight_needed = ( requested_llm_config_id == 0 and llm_config_id < 0 and not is_recently_healthy(llm_config_id) - ): + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight:{llm_config_id}", + ) + + # Create connector service + _t0 = time.perf_counter() + connector_service = ConnectorService(session, search_space_id=search_space_id) + + firecrawl_api_key = None + webcrawler_connector = await connector_service.get_connector_by_type( + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id + ) + if webcrawler_connector and webcrawler_connector.config: + firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") + _perf_log.info( + "[stream_new_chat] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + # Get the PostgreSQL checkpointer for persistent conversation memory + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + _t0 = time.perf_counter() + # Speculative agent build — runs in parallel with the preflight + # task (if any). Built with the *current* ``llm`` / ``agent_config``; + # if preflight reports 429 we will discard this future and rebuild + # against the freshly pinned config below. + agent_build_task = asyncio.create_task( + create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ), + name="agent_build:stream_new_chat", + ) + + agent: Any = None + if preflight_task is not None: try: - await _preflight_llm(llm) + await preflight_task mark_healthy(llm_config_id) _perf_log.info( - "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", llm_config_id, time.perf_counter() - _t_preflight, ) except Exception as preflight_exc: + # Both branches below need the session: the non-429 path + # may unwind via cleanup that uses ``session``, and the + # 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id`` + # against it. Wait for the speculative build to release its + # session usage before we proceed. + await _settle_speculative_agent_build(agent_build_task) if not _is_provider_rate_limited(preflight_exc): raise + # 429: speculative agent is discarded; run the original + # repin → reload → rebuild path against the freshly + # pinned config. previous_config_id = llm_config_id mark_runtime_cooldown( previous_config_id, reason="preflight_rate_limited" @@ -2639,46 +2754,28 @@ async def stream_new_chat( "fallback_config_id": llm_config_id, }, ) + # Rebuild against the new llm/agent_config. Sequential + # here because we no longer have anything to overlap with. 
+ agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) - # Create connector service - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_new_chat] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - # Get the PostgreSQL checkpointer for persistent conversation memory - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - disabled_tools=disabled_tools, - mentioned_document_ids=mentioned_document_ids, - filesystem_selection=filesystem_selection, - ) + if agent is None: + # Either no preflight was needed, or preflight succeeded — + # in both cases the speculative build is the agent we want. + agent = await agent_build_task _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -3005,6 +3102,18 @@ async def stream_new_chat( title_emitted = False + # Build the per-invocation runtime context (Phase 1.5). + # ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware`` + # via ``runtime.context.mentioned_document_ids`` instead of its + # ``__init__`` closure — that way the same compiled-agent instance + # can serve multiple turns with different mention lists. + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + mentioned_document_ids=list(mentioned_document_ids or []), + request_id=request_id, + turn_id=stream_result.turn_id, + ) + _t_stream_start = time.perf_counter() _first_event_logged = False runtime_rate_limit_recovered = False @@ -3028,6 +3137,7 @@ async def stream_new_chat( else FilesystemMode.CLOUD ), fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, ): if not _first_event_logged: _perf_log.info( @@ -3643,21 +3753,75 @@ async def stream_resume_chat( # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: # one cheap probe before the agent is rebuilt so a 429'd pin gets # repinned without burning planner/classifier/title calls first. - if ( + # See ``stream_new_chat`` for the full rationale on the speculative + # parallel build pattern below. 
+ preflight_needed = ( requested_llm_config_id == 0 and llm_config_id < 0 and not is_recently_healthy(llm_config_id) - ): + ) + preflight_task: asyncio.Task[None] | None = None + _t_preflight = 0.0 + if preflight_needed: _t_preflight = time.perf_counter() + preflight_task = asyncio.create_task( + _preflight_llm(llm), + name=f"auto_pin_preflight_resume:{llm_config_id}", + ) + + _t0 = time.perf_counter() + connector_service = ConnectorService(session, search_space_id=search_space_id) + + firecrawl_api_key = None + webcrawler_connector = await connector_service.get_connector_by_type( + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id + ) + if webcrawler_connector and webcrawler_connector.config: + firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") + _perf_log.info( + "[stream_resume] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + + _t0 = time.perf_counter() + agent_build_task = asyncio.create_task( + create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ), + name="agent_build:stream_resume", + ) + + agent: Any = None + if preflight_task is not None: try: - await _preflight_llm(llm) + await preflight_task mark_healthy(llm_config_id) _perf_log.info( - "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)", llm_config_id, time.perf_counter() - _t_preflight, ) except Exception as preflight_exc: + # Same session-safety rationale as ``stream_new_chat``. 
+ await _settle_speculative_agent_build(agent_build_task) if not _is_provider_rate_limited(preflight_exc): raise previous_config_id = llm_config_id @@ -3717,43 +3881,22 @@ async def stream_resume_chat( "fallback_config_id": llm_config_id, }, ) + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) - _t0 = time.perf_counter() - connector_service = ConnectorService(session, search_space_id=search_space_id) - - firecrawl_api_key = None - webcrawler_connector = await connector_service.get_connector_by_type( - SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id - ) - if webcrawler_connector and webcrawler_connector.config: - firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - _perf_log.info( - "[stream_resume] Connector service + firecrawl key in %.3fs", - time.perf_counter() - _t0, - ) - - _t0 = time.perf_counter() - checkpointer = await get_checkpointer() - _perf_log.info( - "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 - ) - - visibility = thread_visibility or ChatVisibility.PRIVATE - - _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( - llm=llm, - search_space_id=search_space_id, - db_session=session, - connector_service=connector_service, - checkpointer=checkpointer, - user_id=user_id, - thread_id=chat_id, - agent_config=agent_config, - firecrawl_api_key=firecrawl_api_key, - thread_visibility=visibility, - filesystem_selection=filesystem_selection, - ) + if agent is None: + agent = await agent_build_task _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 ) @@ -3794,6 +3937,16 @@ async def stream_resume_chat( ) yield streaming_service.format_data("turn-status", {"status": "busy"}) + # Resume path doesn't carry new ``mentioned_document_ids`` — + # those are seeded in the original turn. We still pass a + # context so future middleware extensions (Phase 2) can rely on + # ``runtime.context`` always being populated. + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + request_id=request_id, + turn_id=stream_result.turn_id, + ) + _t_stream_start = time.perf_counter() _first_event_logged = False runtime_rate_limit_recovered = False @@ -3814,6 +3967,7 @@ async def stream_resume_chat( else FilesystemMode.CLOUD ), fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py new file mode 100644 index 000000000..9b3de2db7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py @@ -0,0 +1,268 @@ +"""Regression tests for the compiled-agent cache. + +Covers the cache primitive itself (TTL, LRU, in-flight de-duplication, +build-failure non-caching) and the cache-key signature helpers that +``create_surfsense_deep_agent`` relies on. The integration with +``create_surfsense_deep_agent`` is covered separately by the streaming +contract tests; this module focuses on the primitives so a regression +in the cache implementation is caught before it reaches the agent +factory. 
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +import pytest + +from app.agents.new_chat.agent_cache import ( + flags_signature, + reload_for_tests, + stable_hash, + system_prompt_hash, + tools_signature, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# stable_hash + signature helpers +# --------------------------------------------------------------------------- + + +def test_stable_hash_is_deterministic_across_calls() -> None: + a = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + b = stable_hash("v1", 42, "thread-9", None, ["x", "y"]) + assert a == b + + +def test_stable_hash_changes_when_any_part_changes() -> None: + base = stable_hash("v1", 42, "thread-9") + assert stable_hash("v1", 42, "thread-10") != base + assert stable_hash("v2", 42, "thread-9") != base + assert stable_hash("v1", 43, "thread-9") != base + + +def test_tools_signature_keys_on_name_and_description_not_identity() -> None: + """Two tool lists with the same surface must hash identically. + + The cache key MUST NOT change when the underlying ``BaseTool`` + instances are different Python objects (a fresh request constructs + fresh tool instances every time). Hashing on ``(name, description)`` + keeps the cache hot across requests with identical tool surfaces. + """ + + @dataclass + class FakeTool: + name: str + description: str + + tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")] + tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")] + sig_a = tools_signature( + tools_a, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + sig_b = tools_signature( + tools_b, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_a == sig_b, "tool order must not affect the signature" + + # Adding a tool rotates the key. + tools_c = [*tools_a, FakeTool("gamma", "does gamma")] + sig_c = tools_signature( + tools_c, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + assert sig_c != sig_a + + +def test_tools_signature_rotates_when_connector_set_changes() -> None: + @dataclass + class FakeTool: + name: str + description: str + + tools = [FakeTool("a", "x")] + base = tools_signature( + tools, available_connectors=["NOTION"], available_document_types=["FILE"] + ) + added = tools_signature( + tools, + available_connectors=["NOTION", "SLACK"], + available_document_types=["FILE"], + ) + assert base != added, "adding a connector must rotate the cache key" + + +def test_flags_signature_changes_when_flag_flips() -> None: + @dataclass(frozen=True) + class Flags: + a: bool = True + b: bool = False + + base = flags_signature(Flags()) + flipped = flags_signature(Flags(b=True)) + assert base != flipped + + +def test_system_prompt_hash_is_stable_and_distinct() -> None: + p1 = "You are a helpful assistant." + p2 = "You are a helpful assistant!" 
# one-character delta + assert system_prompt_hash(p1) == system_prompt_hash(p1) + assert system_prompt_hash(p1) != system_prompt_hash(p2) + + +# --------------------------------------------------------------------------- +# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cache_hit_returns_same_instance_on_second_call() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is b, "cache must return the SAME object across hits" + assert builds == 1, "builder must run exactly once" + + +@pytest.mark.asyncio +async def test_cache_different_keys_get_different_instances() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k1", builder=builder) + b = await cache.get_or_build("k2", builder=builder) + assert a is not b + + +@pytest.mark.asyncio +async def test_cache_stale_entries_get_rebuilt() -> None: + # ttl=0 means every read sees the entry as immediately stale. + cache = reload_for_tests(maxsize=8, ttl_seconds=0.0) + builds = 0 + + async def builder() -> object: + nonlocal builds + builds += 1 + return object() + + a = await cache.get_or_build("k", builder=builder) + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "stale entry must rebuild a fresh instance" + assert builds == 2 + + +@pytest.mark.asyncio +async def test_cache_evicts_lru_when_full() -> None: + cache = reload_for_tests(maxsize=2, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("a", builder=builder) + _ = await cache.get_or_build("b", builder=builder) + # Re-touch "a" so "b" is now the LRU victim. + a_again = await cache.get_or_build("a", builder=builder) + assert a_again is a + # Inserting "c" should evict "b" (LRU), not "a". + _ = await cache.get_or_build("c", builder=builder) + assert cache.stats()["size"] == 2 + + # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild). + a_hit = await cache.get_or_build("a", builder=builder) + assert a_hit is a, "LRU must keep the most-recently-used 'a' entry" + + +@pytest.mark.asyncio +async def test_cache_concurrent_misses_coalesce_to_single_build() -> None: + """Two concurrent get_or_build calls on the same key must share one builder.""" + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + build_started = asyncio.Event() + builds = 0 + + async def slow_builder() -> object: + nonlocal builds + builds += 1 + build_started.set() + # Yield control so the second waiter can race against us. + await asyncio.sleep(0.05) + return object() + + task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + # Wait until the first builder has started, then race a second waiter. + await build_started.wait() + task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder)) + + a, b = await asyncio.gather(task_a, task_b) + assert a is b, "coalesced waiters must observe the same value" + assert builds == 1, "concurrent cold misses must collapse to ONE build" + + +@pytest.mark.asyncio +async def test_cache_does_not_store_failed_builds() -> None: + """A builder that raises must NOT poison the cache. 
+ + The next caller for the same key must run the builder again (not + re-raise the cached exception). + """ + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + attempts = 0 + + async def flaky_builder() -> object: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise RuntimeError("transient") + return object() + + with pytest.raises(RuntimeError, match="transient"): + await cache.get_or_build("k", builder=flaky_builder) + + # Second call must retry — not re-raise the cached exception. + value = await cache.get_or_build("k", builder=flaky_builder) + assert value is not None + assert attempts == 2 + + +@pytest.mark.asyncio +async def test_cache_invalidate_drops_entry() -> None: + cache = reload_for_tests(maxsize=8, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + a = await cache.get_or_build("k", builder=builder) + assert cache.invalidate("k") is True + b = await cache.get_or_build("k", builder=builder) + assert a is not b, "post-invalidation lookup must rebuild" + + +@pytest.mark.asyncio +async def test_cache_invalidate_prefix_drops_matching_entries() -> None: + cache = reload_for_tests(maxsize=16, ttl_seconds=60.0) + + async def builder() -> object: + return object() + + await cache.get_or_build("user:1:thread:1", builder=builder) + await cache.get_or_build("user:1:thread:2", builder=builder) + await cache.get_or_build("user:2:thread:1", builder=builder) + + removed = cache.invalidate_prefix("user:1:") + assert removed == 2 + assert cache.stats()["size"] == 1 + + # The user:2 entry must still be hot (no rebuild). + survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder) + assert survivor_value is not None diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index df60a4816..6800be2af 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -34,6 +34,8 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", + "SURFSENSE_ENABLE_AGENT_CACHE", + "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", ]: monkeypatch.delenv(name, raising=False) @@ -62,6 +64,11 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> assert flags.enable_stream_parity_v2 is True assert flags.enable_plugin_loader is False assert flags.enable_otel is False + # Phase 2: agent cache is now default-on (the prerequisite tool + # ``db_session`` refactor landed). The companion gp-subagent share + # flag stays default-off pending data on cold-miss frequency. + assert flags.enable_agent_cache is True + assert flags.enable_agent_cache_share_gp_subagent is False assert flags.any_new_middleware_enabled() is True diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py new file mode 100644 index 000000000..6c323d920 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py @@ -0,0 +1,344 @@ +"""Tests for ``FlattenSystemMessageMiddleware``. + +The middleware exists to defend against Anthropic's "Found 5 cache_control +blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on +the system message and the OpenRouter→Anthropic adapter redistributes +``cache_control`` across all of them. 
The flattening collapses every +all-text system content list to a single string before the LLM call. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import HumanMessage, SystemMessage + +from app.agents.new_chat.middleware.flatten_system import ( + FlattenSystemMessageMiddleware, + _flatten_text_blocks, + _flattened_request, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# _flatten_text_blocks — pure helper, the heart of the middleware. +# --------------------------------------------------------------------------- + + +class TestFlattenTextBlocks: + def test_joins_text_blocks_with_double_newline(self) -> None: + blocks = [ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + assert ( + _flatten_text_blocks(blocks) + == "\n\n\n\n" + ) + + def test_handles_single_text_block(self) -> None: + blocks = [{"type": "text", "text": "only one"}] + assert _flatten_text_blocks(blocks) == "only one" + + def test_handles_empty_list(self) -> None: + assert _flatten_text_blocks([]) == "" + + def test_passes_through_bare_string_blocks(self) -> None: + # LangChain content can mix bare strings and dict blocks. + blocks = ["raw string", {"type": "text", "text": "dict block"}] + assert _flatten_text_blocks(blocks) == "raw string\n\ndict block" + + def test_returns_none_for_image_block(self) -> None: + # System messages with images are rare — but we never want to + # silently lose the image payload by joining as text. + blocks = [ + {"type": "text", "text": "look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png..."}}, + ] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_for_non_dict_non_str_block(self) -> None: + blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item] + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_field_missing(self) -> None: + blocks = [{"type": "text"}] # no ``text`` key + assert _flatten_text_blocks(blocks) is None + + def test_returns_none_when_text_is_not_string(self) -> None: + blocks = [{"type": "text", "text": ["nested", "list"]}] + assert _flatten_text_blocks(blocks) is None + + def test_drops_cache_control_from_inner_blocks(self) -> None: + # The whole point: existing cache_control on inner blocks is + # discarded so LiteLLM's ``cache_control_injection_points`` can + # re-attach exactly one breakpoint after flattening. + blocks = [ + {"type": "text", "text": "first"}, + { + "type": "text", + "text": "second", + "cache_control": {"type": "ephemeral"}, + }, + ] + flattened = _flatten_text_blocks(blocks) + assert flattened == "first\n\nsecond" + assert "cache_control" not in flattened # type: ignore[operator] + + +# --------------------------------------------------------------------------- +# _flattened_request — decides when to override and when to no-op. +# --------------------------------------------------------------------------- + + +def _make_request(system_message: SystemMessage | None) -> Any: + """Build a minimal ModelRequest stub. We only need .system_message + and .override(system_message=...) — the middleware never touches + other fields. 
+ """ + request = MagicMock() + request.system_message = system_message + + def override(**kwargs: Any) -> Any: + new_request = MagicMock() + new_request.system_message = kwargs.get( + "system_message", request.system_message + ) + new_request.messages = kwargs.get("messages", getattr(request, "messages", [])) + new_request.tools = kwargs.get("tools", getattr(request, "tools", [])) + return new_request + + request.override = override + return request + + +class TestFlattenedRequest: + def test_collapses_multi_block_system_to_string(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + ) + request = _make_request(sys) + flattened = _flattened_request(request) + + assert flattened is not None + assert isinstance(flattened.system_message, SystemMessage) + assert flattened.system_message.content == ( + "\n\n\n\n\n\n\n\n" + ) + + def test_no_op_for_string_content(self) -> None: + sys = SystemMessage(content="already a string") + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_for_single_block_list(self) -> None: + # One block already produces one breakpoint — no need to flatten. + sys = SystemMessage(content=[{"type": "text", "text": "single"}]) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_no_op_when_system_message_missing(self) -> None: + request = _make_request(None) + assert _flattened_request(request) is None + + def test_no_op_when_list_contains_non_text_block(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + ] + ) + request = _make_request(sys) + assert _flattened_request(request) is None + + def test_preserves_additional_kwargs_and_metadata(self) -> None: + # Defensive: nothing in the current chain sets these on a system + # message, but losing them silently when something does in the + # future would be a regression. ``name`` in particular is the only + # ``additional_kwargs`` field that ChatLiteLLM's + # ``_convert_message_to_dict`` propagates onto the wire. + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ], + additional_kwargs={"name": "surfsense_system", "x": 1}, + response_metadata={"tokens": 42}, + ) + sys.id = "sys-msg-1" + request = _make_request(sys) + + flattened = _flattened_request(request) + assert flattened is not None + assert flattened.system_message.content == "a\n\nb" + assert flattened.system_message.additional_kwargs == { + "name": "surfsense_system", + "x": 1, + } + assert flattened.system_message.response_metadata == {"tokens": 42} + assert flattened.system_message.id == "sys-msg-1" + + def test_idempotent_when_run_twice(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ] + ) + request = _make_request(sys) + first = _flattened_request(request) + assert first is not None + + # Second pass on the already-flattened request should be a no-op. + # We re-wrap in a request stub since the helper inspects + # ``request.system_message.content``. + second_request = _make_request(first.system_message) + assert _flattened_request(second_request) is None + + +# --------------------------------------------------------------------------- +# Middleware integration — verify the handler sees a flattened request. 
+# --------------------------------------------------------------------------- + + +class TestMiddlewareWrap: + @pytest.mark.asyncio + async def test_async_passes_flattened_request_to_handler(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "text", "text": "beta"}, + ] + ) + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + result = await mw.awrap_model_call(request, handler) + + assert result == "ok" + assert isinstance(captured["request"].system_message, SystemMessage) + assert captured["request"].system_message.content == "alpha\n\nbeta" + + @pytest.mark.asyncio + async def test_async_passes_through_when_already_string(self) -> None: + sys = SystemMessage(content="just a string") + request = _make_request(sys) + captured: dict[str, Any] = {} + + async def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + await mw.awrap_model_call(request, handler) + + # Same request object: no override happened. + assert captured["request"] is request + + def test_sync_passes_flattened_request_to_handler(self) -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "text", "text": "beta"}, + ] + ) + request = _make_request(sys) + captured: dict[str, Any] = {} + + def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + result = mw.wrap_model_call(request, handler) + + assert result == "ok" + assert captured["request"].system_message.content == "alpha\n\nbeta" + + def test_sync_passes_through_when_no_system_message(self) -> None: + request = _make_request(None) + captured: dict[str, Any] = {} + + def handler(req: Any) -> str: + captured["request"] = req + return "ok" + + mw = FlattenSystemMessageMiddleware() + mw.wrap_model_call(request, handler) + assert captured["request"] is request + + +# --------------------------------------------------------------------------- +# Regression guard — pin the worst-case shape that triggered the +# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the +# downstream cache_control_injection_points can only place 1 breakpoint +# on the system message regardless of provider redistribution quirks. +# --------------------------------------------------------------------------- + + +def test_regression_five_block_system_collapses_to_one_block() -> None: + sys = SystemMessage( + content=[ + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + {"type": "text", "text": ""}, + ] + ) + request = _make_request(sys) + flattened = _flattened_request(request) + + assert flattened is not None + assert isinstance(flattened.system_message.content, str) + # The exact join doesn't matter for the cache_control accounting — + # only that there is exactly ONE content block when LiteLLM's + # AnthropicCacheControlHook later targets ``role: system``. + assert " None: + # Sanity: the middleware MUST NOT touch user messages — only the + # system message. Multi-block user content is the path that carries + # image attachments and would lose its image_url block on + # accidental flatten. 
+ sys = SystemMessage( + content=[ + {"type": "text", "text": "a"}, + {"type": "text", "text": "b"}, + ] + ) + user = HumanMessage( + content=[ + {"type": "text", "text": "look at this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, + ] + ) + request = _make_request(sys) + request.messages = [user] + + flattened = _flattened_request(request) + assert flattened is not None + # System flattened to string … + assert isinstance(flattened.system_message.content, str) + # … user message is untouched (the helper does not even look at it). + assert flattened.messages == [user] + assert isinstance(user.content, list) + assert len(user.content) == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py index 5b3a03581..4cf53969d 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -1,4 +1,4 @@ -"""Tests for ``apply_litellm_prompt_caching`` in +r"""Tests for ``apply_litellm_prompt_caching`` in :mod:`app.agents.new_chat.prompt_caching`. The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which @@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to ``litellm.completion(...)``. The tests below pin its public contract: -1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so +1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so savings compound across multi-turn conversations on Anthropic-family - providers. + providers. ``index: 0`` is used (rather than ``role: system``) because + the deepagent stack accumulates multiple ``SystemMessage``\ s in + ``state["messages"]`` and ``role: system`` would tag every one of + them, blowing past Anthropic's 4-block ``cache_control`` cap. 2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic prompt-cache surface is available). @@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None: apply_litellm_prompt_caching(llm) points = llm.model_kwargs["cache_control_injection_points"] - assert {"location": "message", "role": "system"} in points + assert {"location": "message", "index": 0} in points assert {"location": "message", "index": -1} in points assert len(points) == 2 +def test_does_not_inject_role_system_breakpoint() -> None: + """Regression: deliberately AVOID ``role: system`` so we don't tag + every SystemMessage the deepagent ``before_agent`` injectors push + into ``state["messages"]`` (priority, tree, memory, file-intent, + anonymous-doc). Tagging all of them overflows Anthropic's 4-block + ``cache_control`` cap and surfaces as + ``OpenrouterException: A maximum of 4 blocks with cache_control may + be provided. Found N`` 400s. 
+ """ + llm = _FakeLLM() + apply_litellm_prompt_caching(llm) + points = llm.model_kwargs["cache_control_injection_points"] + assert all(p.get("role") != "system" for p in points), ( + f"Expected no role=system breakpoint, got: {points}" + ) + + def test_injection_points_set_for_anthropic_config() -> None: """Anthropic-family configs need the marker — verify it lands.""" cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet") diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 2ca470680..2933a0504 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -475,3 +475,190 @@ class TestKBSearchPlanSchema: ) ) assert plan.is_recency_query is False + + +# ── mentioned_document_ids cross-turn drain ──────────────────────────── + + +class TestKnowledgePriorityMentionDrain: + """Regression tests for the cross-turn ``mentioned_document_ids`` drain. + + The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware` + instance across turns of the same thread. ``mentioned_document_ids`` + can therefore enter the middleware via two paths: + + 1. The constructor closure (``__init__(mentioned_document_ids=...)``) — + seeded by the cache-miss build on turn 1. + 2. ``runtime.context.mentioned_document_ids`` — supplied freshly per + turn by the streaming task. + + Without the drain fix, an empty ``runtime.context.mentioned_document_ids`` + on turn 2 would fall through to the closure (because ``[]`` is falsy in + Python) and replay turn 1's mentions. This class pins down the + correct behaviour: the runtime path is authoritative even when empty, + and the closure is drained the first time the runtime path fires so + no later turn can ever resurrect stale state. + """ + + @staticmethod + def _make_runtime(mention_ids: list[int]): + """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``.""" + from types import SimpleNamespace + + return SimpleNamespace( + context=SimpleNamespace(mentioned_document_ids=mention_ids), + ) + + @staticmethod + def _planner_llm() -> "FakeLLM": + # Planner returns a stable, non-recency plan so we always land in + # the hybrid-search branch (where ``fetch_mentioned_documents`` is + # invoked alongside the main search). + return FakeLLM( + json.dumps( + { + "optimized_query": "follow up question", + "start_date": None, + "end_date": None, + "is_recency_query": False, + } + ) + ) + + async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch): + """Turn 1 with mentions in BOTH closure and runtime context: the + runtime path wins AND the closure is drained so a future turn + cannot replay it. 
+ """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what is in those docs?")]}, + runtime=self._make_runtime([1, 2, 3]), + ) + + assert fetched_ids == [[1, 2, 3]], ( + "runtime.context mentions must be the source of truth on turn 1" + ) + assert middleware.mentioned_document_ids == [], ( + "closure must be drained the first time the runtime path fires " + "so no later turn can replay stale mentions" + ) + + async def test_empty_runtime_context_does_not_replay_closure_mentions( + self, monkeypatch + ): + """Regression: turn 2 with NO mentions must not surface turn 1's + mentions from the constructor closure. + + Before the fix, ``if ctx_mentions:`` treated an empty list as + absent and fell through to ``elif self.mentioned_document_ids:``, + replaying turn 1's mentions. This test pins down the corrected + behaviour. + """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + # Simulate a cached middleware instance whose closure was seeded + # by a previous turn's cache-miss build (mentions=[1,2,3]). + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[1, 2, 3], + ) + + # Turn 2: streaming task supplies an EMPTY mention list (no + # mentions on this follow-up turn). + await middleware.abefore_agent( + {"messages": [HumanMessage(content="what about the next steps?")]}, + runtime=self._make_runtime([]), + ) + + assert fetched_ids == [], ( + "fetch_mentioned_documents must NOT be called when the runtime " + "context says there are no mentions for this turn" + ) + + async def test_legacy_path_fires_only_when_runtime_context_absent( + self, monkeypatch + ): + """Backward-compat: if a caller doesn't supply runtime.context (old + non-streaming code path), the closure-injected mentions are still + honoured exactly once and then drained. 
+ """ + fetched_ids: list[list[int]] = [] + + async def fake_fetch_mentioned_documents(*, document_ids, search_space_id): + fetched_ids.append(list(document_ids)) + return [] + + async def fake_search_knowledge_base(**_kwargs): + return [] + + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents", + fake_fetch_mentioned_documents, + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", + fake_search_knowledge_base, + ) + + middleware = KnowledgeBaseSearchMiddleware( + llm=self._planner_llm(), + search_space_id=42, + mentioned_document_ids=[7, 8], + ) + + # First call: no runtime → legacy path uses the closure. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="initial question")]}, + runtime=None, + ) + # Second call: still no runtime — closure already drained, so no replay. + await middleware.abefore_agent( + {"messages": [HumanMessage(content="follow up")]}, + runtime=None, + ) + + assert fetched_ids == [[7, 8]], ( + "legacy path must honour the closure exactly once and then drain it" + ) + assert middleware.mentioned_document_ids == [] diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index cc8157464..64e4d5157 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model(): await _preflight_llm(fake_llm) +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_swallows_exceptions(): + """``_settle_speculative_agent_build`` MUST always return cleanly so the + caller can safely re-touch the request-scoped session afterwards. + + The helper guards the parallel preflight + agent-build path: when the + speculative build is being discarded (429 or non-429 preflight failure) + we await it solely to release any in-flight ``AsyncSession`` usage — + the build's outcome is irrelevant. Any exception (including + ``CancelledError``) leaking out would skip the caller's recovery flow + and re-introduce the very session-concurrency hazard the helper exists + to prevent. + """ + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _raises() -> None: + raise RuntimeError("speculative build crashed") + + async def _succeeds() -> str: + return "agent" + + async def _slow() -> None: + await asyncio.sleep(0.05) + + for coro in (_raises(), _succeeds(), _slow()): + task = asyncio.create_task(coro) + await _settle_speculative_agent_build(task) + assert task.done() + + +@pytest.mark.asyncio +async def test_settle_speculative_agent_build_handles_already_done_task(): + """Done tasks (success or failure) must still be settled without raising.""" + import asyncio + + from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build + + async def _ok() -> str: + return "ok" + + async def _bad() -> None: + raise ValueError("nope") + + ok_task = asyncio.create_task(_ok()) + bad_task = asyncio.create_task(_bad()) + # Drive both to completion before settling. 
+ await asyncio.sleep(0) + await asyncio.sleep(0) + + await _settle_speculative_agent_build(ok_task) + await _settle_speculative_agent_build(bad_task) + assert ok_task.result() == "ok" + # ``bad_task`` exception was consumed by the settle helper; calling + # ``.exception()`` after the fact must still return the original error + # (the helper observes it but doesn't clear it). + assert isinstance(bad_task.exception(), ValueError) + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 07c11b4d6..7616d461d 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -254,8 +254,8 @@ function PricingFAQ() { Frequently Asked Questions
-							Everything you need to know about SurfSense pages, premium credits, and billing. Can't
-							find what you need? Reach out at{" "}
+							Everything you need to know about SurfSense pages, premium credits, and billing.
+							Can't find what you need? Reach out at{" "}
 							rohan@surfsense.com
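
The ORM event listeners registered in connector_service cover mapped-object writes; as the comments there note, bulk-import paths that bypass the ORM still need an explicit invalidation call. A minimal sketch of that pairing — `bulk_import_documents`, its arguments, and the row shape are illustrative assumptions, not part of this patch:

from sqlalchemy import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document
from app.services.connector_service import invalidate_connector_discovery_cache


async def bulk_import_documents(
    session: AsyncSession, search_space_id: int, rows: list[dict]
) -> None:
    # Core-style executemany: no mapped instances are flushed, so the
    # after_insert listeners wired up in connector_service never fire and
    # the cached "available document types" would stay stale until the
    # TTL (SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS) expires on its own.
    await session.execute(insert(Document), rows)
    await session.commit()

    # Drop the cached discovery payload for this space so the next chat
    # turn sees the newly indexed document types immediately.
    invalidate_connector_discovery_cache(search_space_id)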
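
The `mentioned_document_ids` drain behaviour pinned by the knowledge_search tests reads most easily as a single resolution rule. The sketch below is an assumption-labelled approximation — `resolve_turn_mentions` is a hypothetical standalone helper, and the real KnowledgeBaseSearchMiddleware change elsewhere in this PR may wire it differently:

def resolve_turn_mentions(middleware, runtime) -> list[int]:
    """Return the mention list for THIS turn and drain any stale closure state."""
    ctx = getattr(runtime, "context", None) if runtime is not None else None
    if ctx is not None:
        # Runtime context is authoritative: an empty list means "no mentions
        # this turn", never "fall back to the constructor closure".
        mentions = list(getattr(ctx, "mentioned_document_ids", None) or [])
        # Drain the closure so no later turn (or legacy call) can replay it.
        middleware.mentioned_document_ids = []
        return mentions
    # Legacy path (no runtime supplied): honour the closure exactly once,
    # then drain it so repeated legacy calls cannot resurrect old mentions.
    mentions = list(middleware.mentioned_document_ids or [])
    middleware.mentioned_document_ids = []
    return mentions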