Merge remote-tracking branch 'upstream/dev' into feat/ui-revamp

This commit is contained in:
Anish Sarkar 2026-05-03 18:58:55 +05:30
commit 4e8c552440
142 changed files with 14603 additions and 6056 deletions

View file

@ -1 +1 @@
0.0.19 0.0.20

View file

@ -308,6 +308,24 @@ STT_SERVICE=local/base
# Advanced (optional) # Advanced (optional)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# New-chat agent feature flags
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_BUSY_MUTEX=true
SURFSENSE_ENABLE_SKILLS=true
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
SURFSENSE_ENABLE_ACTION_LOG=true
SURFSENSE_ENABLE_REVERT_ROUTE=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
# Periodic connector sync interval (default: 5m) # Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m # SCHEDULE_CHECKER_INTERVAL=5m

View file

@ -324,3 +324,30 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_PLUGIN_LOADER=false # SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names # Comma-separated allowlist of plugin entry-point names
# SURFSENSE_ALLOWED_PLUGINS=year_substituter # SURFSENSE_ALLOWED_PLUGINS=year_substituter
# -----------------------------------------------------------------------------
# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
# -----------------------------------------------------------------------------
# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
# on a cold turn) is reused across subsequent turns on the same thread,
# collapsing it to a microsecond hash lookup. All connector tools acquire
# their own short-lived DB session per call (Phase 2 refactor) so a cached
# closure is safe to share across requests. Flip OFF only as a last-resort
# rollback if you suspect cache-related staleness.
# SURFSENSE_ENABLE_AGENT_CACHE=true
# Cache capacity (max number of compiled-agent entries kept in memory)
# and TTL per entry (seconds). Working set is typically one entry per
# active thread on this replica; tune up for very large deployments.
# SURFSENSE_AGENT_CACHE_MAXSIZE=256
# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
# -----------------------------------------------------------------------------
# Connector discovery TTL cache (Phase 1.4 perf optimization)
# -----------------------------------------------------------------------------
# Caches the per-search-space "available connectors" + "available document
# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
# ORM event listeners auto-invalidate on connector / document inserts,
# updates and deletes — the TTL only bounds staleness for bulk-import
# paths that bypass the ORM. Set to 0 to disable the cache.
# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30

View file

@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
COPY pyproject.toml . COPY pyproject.toml .
COPY uv.lock . COPY uv.lock .
# Install PyTorch based on architecture # Install all Python dependencies from uv.lock for deterministic builds.
RUN if [ "$(uname -m)" = "x86_64" ]; then \ #
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \ # `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
else \ # which lets prod silently drift to newer upstream versions on every rebuild
pip install --no-cache-dir torch torchvision torchaudio; \ # (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
fi # Exporting the lock to requirements.txt and feeding it to `uv pip install`
# pins every transitive package to the exact version captured in uv.lock.
# Install python dependencies #
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
# downloads that the lock-based install immediately replaced. If a specific
# CUDA version is needed (driver compatibility, etc.), wire it through
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
RUN pip install --no-cache-dir uv && \ RUN pip install --no-cache-dir uv && \
uv pip install --system --no-cache-dir -e . uv export --frozen --no-dev --no-hashes --no-emit-project \
--format requirements-txt -o /tmp/requirements.txt && \
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
rm /tmp/requirements.txt
# Set SSL environment variables dynamically # Set SSL environment variables dynamically
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \ RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
# Pre-download Docling models # Pre-download Docling models
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
# Install Playwright browsers for web scraping if needed # Install Playwright browsers for web scraping (the playwright package itself
RUN pip install playwright && \ # is already installed via uv.lock above)
playwright install chromium --with-deps RUN playwright install chromium --with-deps
# Copy source code # Copy source code
COPY . . COPY . .
# Install the project itself in editable mode. Dependencies were already
# installed deterministically from uv.lock above, so --no-deps prevents any
# re-resolution that could pull newer versions.
RUN uv pip install --system --no-cache-dir --no-deps -e .
# Copy and set permissions for entrypoint script # Copy and set permissions for entrypoint script
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts) # Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh

View file

@ -0,0 +1,357 @@
"""TTL-LRU cache for compiled SurfSense deep agents.
Why this exists
---------------
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
turn:
1. Discover connectors & document types from Postgres (~50-200ms)
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
3. Compose the system prompt
4. Construct ~15 middleware instances (CPU)
5. Eagerly compile the general-purpose subagent
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
which builds a second LangGraph + Pydantic schemas ~1.5-2s of pure
CPU work)
6. Compile the outer LangGraph
For a single thread, all six steps produce the SAME object on every turn
unless the user has changed their LLM config, toggled a feature flag,
added a connector, etc. The right answer is to compile ONCE per
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
every subsequent turn on the same thread.
Why a per-thread key (not a global pool)
----------------------------------------
Most middleware in the SurfSense stack captures per-thread state in
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
would silently leak state across users and threads. Keying the cache on
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
turns on the same thread without changing any middleware's behavior.
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
(read via ``runtime.context``) so the cache can collapse to a single
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
then, per-thread keying is the only safe option.
Cache shape
-----------
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
minutes matches a typical chat session). ``maxsize`` (default 256)
caps memory; LRU evicts least-recently-used on overflow.
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
cold misses on the same key wait for the first build instead of
building N times.
* Process-local: this is an in-memory cache. Multi-replica deployments
pay the build cost once per replica per key. That's fine; the working
set per replica is small (one entry per active thread on that replica).
Telemetry
---------
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
* ``hit`` cache hit, microseconds-fast
* ``miss`` first build for this key, includes build duration
* ``stale`` entry was found but expired; rebuilt
* ``evict`` LRU eviction (size-limited)
* ``size`` current cache occupancy at lookup time
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# ---------------------------------------------------------------------------
# Public API: signature helpers (cache key components)
# ---------------------------------------------------------------------------
def stable_hash(*parts: Any) -> str:
"""Compute a deterministic SHA1 of the str repr of ``parts``.
Used for cache key components that need a fixed-width representation
(system prompt, tool list, etc.). SHA1 is fine here this is not a
security boundary, just a content fingerprint.
"""
h = hashlib.sha1(usedforsecurity=False)
for p in parts:
h.update(repr(p).encode("utf-8", errors="replace"))
h.update(b"\x1f") # ASCII unit separator between parts
return h.hexdigest()
def tools_signature(
tools: list[Any] | tuple[Any, ...],
*,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> str:
"""Hash the bound-tool surface for cache-key purposes.
The signature changes whenever:
* A tool is added or removed from the bound list (built-in toggles,
MCP tools loaded for the user changes, gating rules flip, etc.).
* The available connectors / document types for the search space
change (new connector added, last connector removed, new document
type indexed). Because :func:`get_connector_gated_tools` derives
``modified_disabled_tools`` from ``available_connectors``, the
tool surface is technically already covered but we hash the
connector list separately so an empty-list "no tools changed"
situation still rotates the key when, say, the user re-adds a
connector that gates a tool we were already not exposing.
Stays stable across:
* Process restarts (tool names + descriptions are static).
* Different replicas (everyone gets the same hash for the same
inputs).
"""
tool_descriptors = sorted(
(getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
)
connectors = sorted(available_connectors or [])
doc_types = sorted(available_document_types or [])
return stable_hash(tool_descriptors, connectors, doc_types)
def flags_signature(flags: Any) -> str:
"""Hash the resolved :class:`AgentFeatureFlags` dataclass.
Frozen dataclasses are deterministically reprable, so a SHA1 of their
repr is a stable fingerprint. Restart safe (flags are read once at
process boot).
"""
return stable_hash(repr(flags))
def system_prompt_hash(system_prompt: str) -> str:
"""Hash a system prompt string. Cheap, ~30µs for typical prompts."""
return hashlib.sha1(
system_prompt.encode("utf-8", errors="replace"),
usedforsecurity=False,
).hexdigest()
# ---------------------------------------------------------------------------
# Cache implementation
# ---------------------------------------------------------------------------
@dataclass
class _Entry:
value: Any
created_at: float
last_used_at: float
class _AgentCache:
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
NOT THREAD-SAFE in the multithreading sense designed for a single
asyncio event loop. Uvicorn runs one event loop per worker process,
so this is fine; multi-worker deployments simply each maintain their
own cache.
"""
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
self._maxsize = maxsize
self._ttl = ttl_seconds
self._entries: OrderedDict[str, _Entry] = OrderedDict()
# One lock per key — guards "build" so concurrent cold misses on
# the same key wait for the first build instead of all racing.
self._locks: dict[str, asyncio.Lock] = {}
def _now(self) -> float:
return time.monotonic()
def _is_fresh(self, entry: _Entry) -> bool:
return (self._now() - entry.created_at) < self._ttl
def _evict_if_full(self) -> None:
while len(self._entries) >= self._maxsize:
evicted_key, _ = self._entries.popitem(last=False)
self._locks.pop(evicted_key, None)
_perf_log.info(
"[agent_cache] evict key=%s reason=lru size=%d",
_short(evicted_key),
len(self._entries),
)
def _touch(self, key: str, entry: _Entry) -> None:
entry.last_used_at = self._now()
self._entries.move_to_end(key, last=True)
async def get_or_build(
self,
key: str,
*,
builder: Callable[[], Awaitable[Any]],
) -> Any:
"""Return the cached value for ``key`` or call ``builder()`` to make it.
``builder`` MUST be idempotent concurrent cold misses on the
same key collapse to a single ``builder()`` call (the others
wait on the in-flight lock and observe the populated entry on
wake).
"""
# Fast path: hot hit.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
# Stale entry — drop it; rebuild below.
if entry is not None and not self._is_fresh(entry):
_perf_log.info(
"[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
_short(key),
self._now() - entry.created_at,
self._ttl,
)
self._entries.pop(key, None)
# Slow path: serialize concurrent misses for the same key.
lock = self._locks.setdefault(key, asyncio.Lock())
async with lock:
# Double-check after acquiring the lock — another waiter may
# have populated the entry while we slept.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
t0 = time.perf_counter()
try:
value = await builder()
except BaseException:
# Don't cache failed builds; let the next caller retry.
_perf_log.warning(
"[agent_cache] build_failed key=%s elapsed=%.3fs",
_short(key),
time.perf_counter() - t0,
)
raise
elapsed = time.perf_counter() - t0
# Insert + evict.
self._evict_if_full()
now = self._now()
self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
self._entries.move_to_end(key, last=True)
_perf_log.info(
"[agent_cache] miss key=%s build=%.3fs size=%d",
_short(key),
elapsed,
len(self._entries),
)
return value
def invalidate(self, key: str) -> bool:
"""Drop a single entry; return True if anything was removed."""
removed = self._entries.pop(key, None) is not None
self._locks.pop(key, None)
if removed:
_perf_log.info(
"[agent_cache] invalidate key=%s size=%d",
_short(key),
len(self._entries),
)
return removed
def invalidate_prefix(self, prefix: str) -> int:
"""Drop every entry whose key starts with ``prefix``. Returns count."""
keys = [k for k in self._entries if k.startswith(prefix)]
for k in keys:
self._entries.pop(k, None)
self._locks.pop(k, None)
if keys:
_perf_log.info(
"[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
_short(prefix),
len(keys),
len(self._entries),
)
return len(keys)
def clear(self) -> None:
n = len(self._entries)
self._entries.clear()
self._locks.clear()
if n:
_perf_log.info("[agent_cache] clear removed=%d", n)
def stats(self) -> dict[str, Any]:
return {
"size": len(self._entries),
"maxsize": self._maxsize,
"ttl_seconds": self._ttl,
}
def _short(key: str, n: int = 16) -> str:
"""Truncate keys for log lines so they don't blow up log volume."""
return key if len(key) <= n else f"{key[:n]}..."
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
def get_cache() -> _AgentCache:
"""Return the process-wide compiled-agent cache singleton."""
return _cache
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
"""Replace the singleton with a fresh cache. Tests only."""
global _cache
_cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
return _cache
__all__ = [
"flags_signature",
"get_cache",
"reload_for_tests",
"stable_hash",
"system_prompt_hash",
"tools_signature",
]

View file

@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_backends import build_backend_resolver
@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware, DedupHITLToolCallsMiddleware,
DoomLoopMiddleware, DoomLoopMiddleware,
FileIntentMiddleware, FileIntentMiddleware,
FlattenSystemMessageMiddleware,
KnowledgeBasePersistenceMiddleware, KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware, KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware, KnowledgeTreeMiddleware,
@ -330,23 +338,39 @@ async def create_surfsense_deep_agent(
else None, else None,
) )
# Discover available connectors and document types for this search space # Discover available connectors and document types for this search space.
#
# NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
# ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
# SQLAlchemy explicitly forbids concurrent operations on the same session
# ("This session is provisioning a new connection; concurrent operations
# are not permitted on the same session"). The Phase 1.4 in-process TTL
# cache in ``connector_service`` already collapses the warm path to a
# near-zero pair of dict lookups, so sequential awaits cost nothing in
# the common case while remaining correct on cold cache misses.
available_connectors: list[str] | None = None available_connectors: list[str] | None = None
available_document_types: list[str] | None = None available_document_types: list[str] | None = None
_t0 = time.perf_counter() _t0 = time.perf_counter()
try: try:
connector_types = await connector_service.get_available_connectors( try:
connector_types_result = await connector_service.get_available_connectors(
search_space_id search_space_id
) )
if connector_types: if connector_types_result:
available_connectors = _map_connectors_to_searchable_types(connector_types) available_connectors = _map_connectors_to_searchable_types(
connector_types_result
available_document_types = await connector_service.get_available_document_types(
search_space_id
) )
except Exception as e: except Exception as e:
logging.warning("Failed to discover available connectors: %s", e)
try:
available_document_types = (
await connector_service.get_available_document_types(search_space_id)
)
except Exception as e:
logging.warning("Failed to discover available document types: %s", e)
except Exception as e: # pragma: no cover - defensive outer guard
logging.warning(f"Failed to discover available connectors/document types: {e}") logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info( _perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs", "[create_agent] Connector/doc-type discovery in %.3fs",
@ -469,8 +493,16 @@ async def create_surfsense_deep_agent(
# entire middleware build + main-graph compile into a single # entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive. # event loop stays responsive.
_t0 = time.perf_counter() #
agent = await asyncio.to_thread( # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
# on every per-request value that any middleware in the stack closes
# over in ``__init__`` — drop one and you risk leaking state across
# threads. Hits collapse this whole block to a microsecond lookup;
# misses pay the original CPU cost AND populate the cache.
config_id = agent_config.config_id if agent_config is not None else None
async def _build_agent() -> Any:
return await asyncio.to_thread(
_build_compiled_agent_blocking, _build_compiled_agent_blocking,
llm=llm, llm=llm,
tools=tools, tools=tools,
@ -484,14 +516,54 @@ async def create_surfsense_deep_agent(
anon_session_id=anon_session_id, anon_session_id=anon_session_id,
available_connectors=available_connectors, available_connectors=available_connectors,
available_document_types=available_document_types, available_document_types=available_document_types,
# ``mentioned_document_ids`` is consumed by
# ``KnowledgePriorityMiddleware`` per turn via
# ``runtime.context`` (Phase 1.5). We still pass the
# caller-provided list here for the legacy fallback path
# (cache disabled / context not propagated) — the middleware
# drains its own copy after the first read so a cached graph
# never replays stale mentions.
mentioned_document_ids=mentioned_document_ids, mentioned_document_ids=mentioned_document_ids,
max_input_tokens=_max_input_tokens, max_input_tokens=_max_input_tokens,
flags=_flags, flags=_flags,
checkpointer=checkpointer, checkpointer=checkpointer,
) )
_t0 = time.perf_counter()
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
# Cache key components — order matters only for human readability;
# the resulting hash is what's stored. Every component must
# rotate on a real shape change AND stay stable across identical
# invocations.
cache_key = stable_hash(
"v1", # schema version of the key — bump if components change
config_id,
thread_id,
user_id,
search_space_id,
visibility,
filesystem_selection.mode,
anon_session_id,
tools_signature(
tools,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
flags_signature(_flags),
system_prompt_hash(final_system_prompt),
_max_input_tokens,
# ``mentioned_document_ids`` deliberately omitted — middleware
# reads it from ``runtime.context`` (Phase 1.5).
)
agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
else:
agent = await _build_agent()
_perf_log.info( _perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs", "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
"on"
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
else "off",
) )
_perf_log.info( _perf_log.info(
@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking(
noop_mw, noop_mw,
retry_mw, retry_mw,
fallback_mw, fallback_mw,
# Coalesce a multi-text-block system message into one block
# immediately before the model call. Sits innermost on the
# system-message-mutation chain so it observes every appender
# (todo / filesystem / skills / subagents …) and prevents
# OpenRouter→Anthropic from redistributing ``cache_control``
# across N blocks and tripping Anthropic's 4-breakpoint cap.
# See ``middleware/flatten_system.py`` for full rationale.
FlattenSystemMessageMiddleware(),
# Tool-call repair must run after model emits but before # Tool-call repair must run after model emits but before
# permission / dedup / doom-loop interpret the calls. # permission / dedup / doom-loop interpret the calls.
repair_mw, repair_mw,

View file

@ -1,10 +1,25 @@
""" """
Context schema definitions for SurfSense agents. Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent. This module defines the per-invocation context object passed to the SurfSense
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
The agent's compiled graph is the same across invocations (and cached by
``agent_cache``), so anything that varies per turn the user mentions a
specific document, the front-end issues a unique ``request_id``, etc.
MUST live on this context object instead of being captured into a
middleware ``__init__`` closure. Middlewares read fields back via
``runtime.context.<field>``; tools read them via ``runtime.context``.
This object is read inside both ``KnowledgePriorityMiddleware`` (for
``mentioned_document_ids``) and any future middleware that needs
per-request state without invalidating the compiled-agent cache.
""" """
from typing import NotRequired, TypedDict from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
class FileOperationContractState(TypedDict): class FileOperationContractState(TypedDict):
@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
turn_id: str turn_id: str
class SurfSenseContextSchema(TypedDict): @dataclass
class SurfSenseContextSchema:
""" """
Custom state schema for the SurfSense deep agent. Per-invocation context for the SurfSense deep agent.
This extends the default agent state with custom fields. Defaults are chosen so the dataclass can be safely default-constructed
The default state already includes: (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
- messages: Conversation history context is supplied see ``langgraph.runtime.Runtime``). All fields
- todos: Task list from TodoListMiddleware are optional; consumers must None-check before reading.
- files: Virtual filesystem from FilesystemMiddleware
We're adding fields needed for knowledge base search: Phase 1.5 fields:
- search_space_id: The user's search space ID search_space_id: Search space the request is scoped to.
- db_session: Database session (injected at runtime) mentioned_document_ids: KB documents the user @-mentioned this turn.
- connector_service: Connector service instance (injected at runtime) Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key that's the
whole point of putting it here.
file_operation_contract: One-shot file operation contract emitted
by ``FileIntentMiddleware`` for the upcoming turn.
turn_id / request_id: Correlation IDs surfaced by the streaming
task; populated for telemetry.
Phase 2 will extend with: thread_id, user_id, visibility,
filesystem_mode, anon_session_id, available_connectors,
available_document_types, created_by_id (everything currently captured
by middleware ``__init__`` closures).
""" """
search_space_id: int search_space_id: int | None = None
file_operation_contract: NotRequired[FileOperationContractState] mentioned_document_ids: list[int] = field(default_factory=list)
turn_id: NotRequired[str] file_operation_contract: FileOperationContractState | None = None
request_id: NotRequired[str] turn_id: str | None = None
# These are runtime-injected and won't be serialized request_id: str | None = None
# db_session and connector_service are passed when invoking the agent

View file

@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode, These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
SurfSense-native). They follow a "default-OFF for risky things, SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
default-ON for safe upgrades, master kill-switch for everything new" model. image updates work even when older installs do not have newly introduced
environment variables. Risky/experimental integrations stay default OFF,
and the master kill-switch can still disable everything new.
All new middleware checks its flag at agent build time. If the master All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
Examples Examples
-------- --------
Local development (recommended for trying everything except doom-loop / selector): Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else): Master kill-switch (overrides everything else):
@ -60,32 +65,28 @@ class AgentFeatureFlags:
disable_new_agent_stack: bool = False disable_new_agent_stack: bool = False
# Agent quality — context budget, retry/limits, name-repair, doom-loop # Agent quality — context budget, retry/limits, name-repair, doom-loop
enable_context_editing: bool = False enable_context_editing: bool = True
enable_compaction_v2: bool = False enable_compaction_v2: bool = True
enable_retry_after: bool = False enable_retry_after: bool = True
enable_model_fallback: bool = False enable_model_fallback: bool = False
enable_model_call_limit: bool = False enable_model_call_limit: bool = True
enable_tool_call_limit: bool = False enable_tool_call_limit: bool = True
enable_tool_call_repair: bool = False enable_tool_call_repair: bool = True
enable_doom_loop: bool = ( enable_doom_loop: bool = True
False # Default OFF until UI handles permission='doom_loop'
)
# Safety — permissions, concurrency, tool-set narrowing # Safety — permissions, concurrency, tool-set narrowing
enable_permission: bool = False # Default OFF for first deploy enable_permission: bool = True
enable_busy_mutex: bool = False enable_busy_mutex: bool = True
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
# Skills + subagents # Skills + subagents
enable_skills: bool = False enable_skills: bool = True
enable_specialized_subagents: bool = False enable_specialized_subagents: bool = True
enable_kb_planner_runnable: bool = False enable_kb_planner_runnable: bool = True
# Snapshot / revert # Snapshot / revert
enable_action_log: bool = False enable_action_log: bool = True
enable_revert_route: bool = ( enable_revert_route: bool = True
False # Backend ships before UI; route returns 503 until this flips
)
# Streaming parity v2 — opt in to LangChain's structured # Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input # ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@ -94,7 +95,7 @@ class AgentFeatureFlags:
# text path and the synthetic ``call_<run_id>`` tool-call id (no # text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136 # ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible. # ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = False enable_stream_parity_v2: bool = True
# Plugins # Plugins
enable_plugin_loader: bool = False enable_plugin_loader: bool = False
@ -102,6 +103,41 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False enable_otel: bool = False
# Performance — compiled-agent cache (Phase 1 + Phase 2).
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
# graph if the cache key matches (LLM config + thread + tool surface +
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
# wall clock from ~4-5s to <50µs on cache hits.
#
# SAFETY (Phase 2 unblocked this default-on):
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
# ``update_memory``, ``search_surfsense_docs``) now acquire fresh
# short-lived ``AsyncSession`` instances per call via
# :data:`async_session_maker`. The factory still accepts ``db_session``
# for registry compatibility but ``del``'s it immediately — see any
# of those files' factory docstrings for the rationale. The ``llm``
# closure is per-(provider, model, config_id) which is already in
# the cache key, so the LLM is safe to share across cached hits of
# the same key. The KB priority middleware reads
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
# not its constructor closure, so the same compiled agent serves
# turns with different mention lists correctly.
#
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
# environment if a regression surfaces. The path is exercised by
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
enable_agent_cache: bool = True
# Phase 1 (deferred — measure first): pre-build & share the
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
# misses. Only helps when the outer cache MISSES (cache hits already
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
# until we have data showing cold misses are frequent enough to
# justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False
@classmethod @classmethod
def from_env(cls) -> AgentFeatureFlags: def from_env(cls) -> AgentFeatureFlags:
"""Read flags from environment. """Read flags from environment.
@ -115,48 +151,76 @@ class AgentFeatureFlags:
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
"middleware is forced OFF for this build." "middleware is forced OFF for this build."
) )
return cls(disable_new_agent_stack=True) return cls(
disable_new_agent_stack=True,
enable_context_editing=False,
enable_compaction_v2=False,
enable_retry_after=False,
enable_model_fallback=False,
enable_model_call_limit=False,
enable_tool_call_limit=False,
enable_tool_call_repair=False,
enable_doom_loop=False,
enable_permission=False,
enable_busy_mutex=False,
enable_llm_tool_selector=False,
enable_skills=False,
enable_specialized_subagents=False,
enable_kb_planner_runnable=False,
enable_action_log=False,
enable_revert_route=False,
enable_stream_parity_v2=False,
enable_plugin_loader=False,
enable_otel=False,
enable_agent_cache=False,
enable_agent_cache_share_gp_subagent=False,
)
return cls( return cls(
disable_new_agent_stack=False, disable_new_agent_stack=False,
# Agent quality # Agent quality
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool( enable_model_call_limit=_env_bool(
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
), ),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
enable_tool_call_repair=_env_bool( enable_tool_call_repair=_env_bool(
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
), ),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
# Safety # Safety
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
enable_llm_tool_selector=_env_bool( enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
), ),
# Skills + subagents # Skills + subagents
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
enable_specialized_subagents=_env_bool( enable_specialized_subagents=_env_bool(
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
), ),
enable_kb_planner_runnable=_env_bool( enable_kb_planner_runnable=_env_bool(
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
), ),
# Snapshot / revert # Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2 # Streaming parity v2
enable_stream_parity_v2=_env_bool( enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False "SURFSENSE_ENABLE_STREAM_PARITY_V2", True
), ),
# Plugins # Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability # Observability
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
# Performance
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
enable_agent_cache_share_gp_subagent=_env_bool(
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
),
) )
def any_new_middleware_enabled(self) -> bool: def any_new_middleware_enabled(self) -> bool:

View file

@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk yield chunk
# Provider mapping for LiteLLM model string construction # Provider mapping for LiteLLM model string construction.
PROVIDER_MAP = { #
"OPENAI": "openai", # Single source of truth lives in
"ANTHROPIC": "anthropic", # :mod:`app.services.provider_capabilities` so the YAML loader (which
"GROQ": "groq", # runs during ``app.config`` class-body init) can resolve provider
"COHERE": "cohere", # prefixes without dragging the agent / tools tree into module load
"GOOGLE": "gemini", # order. Re-exported here under the historical ``PROVIDER_MAP`` name
"OLLAMA": "ollama_chat", # so existing callers (``llm_router_service``, ``image_gen_router_service``,
"MISTRAL": "mistral", # tests) keep working unchanged.
"AZURE_OPENAI": "azure", from app.services.provider_capabilities import ( # noqa: E402
"OPENROUTER": "openrouter", _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
"XAI": "xai", )
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@ -178,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog.
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
# which prefers OpenRouter's ``architecture.input_modalities`` and
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
supports_image_input: bool = True
@classmethod @classmethod
def from_auto_mode(cls) -> "AgentConfig": def from_auto_mode(cls) -> "AgentConfig":
""" """
@ -203,6 +191,12 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually
# contains at least one vision-capable deployment; the router
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
supports_image_input=True,
) )
@classmethod @classmethod
@ -216,10 +210,24 @@ class AgentConfig:
Returns: Returns:
AgentConfig instance AgentConfig instance
""" """
return cls( # Lazy import to avoid pulling provider_capabilities (and its
provider=config.provider.value # transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
config.provider.value
if hasattr(config.provider, "value") if hasattr(config.provider, "value")
else str(config.provider), else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
return cls(
provider=provider_value,
model_name=config.model_name, model_name=config.model_name,
api_key=config.api_key, api_key=config.api_key,
api_base=config.api_base, api_base=config.api_base,
@ -235,6 +243,16 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we
# ask LiteLLM (default-allow on unknown). The streaming
# safety net still blocks if the model is *explicitly*
# marked text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
),
) )
@classmethod @classmethod
@ -253,15 +271,46 @@ class AgentConfig:
Returns: Returns:
AgentConfig instance AgentConfig instance
""" """
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string # Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "") system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
# Explicit YAML override wins; otherwise derive from LiteLLM /
# OpenRouter modalities. The YAML loader already populates this
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
# so we re-derive here for safety. The bool() coercion preserves
# the loader's behaviour for explicit ``true`` / ``false``
# strings that PyYAML may surface.
if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
)
return cls( return cls(
provider=yaml_config.get("provider", "").upper(), provider=provider,
model_name=yaml_config.get("model_name", ""), model_name=model_name,
api_key=yaml_config.get("api_key", ""), api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"), api_base=yaml_config.get("api_base"),
custom_provider=yaml_config.get("custom_provider"), custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"), litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility) # Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None, system_instructions=system_instructions if system_instructions else None,
@ -276,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium", is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False), anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
supports_image_input=supports_image_input,
) )

View file

@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import ( from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware, SurfSenseFilesystemMiddleware,
) )
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
)
from app.agents.new_chat.middleware.kb_persistence import ( from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware, KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state, commit_staged_filesystem_state,
@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware", "DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware", "DoomLoopMiddleware",
"FileIntentMiddleware", "FileIntentMiddleware",
"FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware", "KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware", "KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware", "KnowledgePriorityMiddleware",

View file

@ -0,0 +1,233 @@
r"""Coalesce multi-block system messages into a single text block.
Several middlewares in our deepagent stack each call
``append_to_system_message`` on the way down to the model
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
``SkillsMiddleware``, ``SubAgentMiddleware`` ). By the time the
request reaches the LLM, the system message has 5+ separate text blocks.
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
request**, and we configure 2 injection points
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
the prepended ``request.system_message``, this middleware is the
defensive partner: it guarantees that "the system block" is *one*
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
OpenRouterAnthropic transformer can never multiply our budget into
several breakpoints by spreading ``cache_control`` across multiple
text blocks of a multi-block system content.
Without flattening we used to see::
OpenrouterException - {"error":{"message":"Provider returned error",
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
cache_control may be provided. Found 5."}}}
(Same error class documented in
https://github.com/BerriAI/litellm/issues/15696 and
https://github.com/BerriAI/litellm/issues/20485 the litellm-side fix
in PR #15395 covers the litellm transformer but does not protect us
when the OpenRouter SaaS itself does the redistribution.)
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
the first injection point from ``role: system`` to ``index: 0``)
neutralises the *primary* cause of the same 400 multiple
``SystemMessage``\ s injected by ``before_agent`` middlewares
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
turns, each tagged with ``cache_control`` by the ``role: system``
matcher. This middleware remains useful as defence-in-depth against
the multi-block redistribution path.
Placement: innermost on the system-message-mutation chain, after every
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
summarization, but before ``noop``/``retry``/``fallback`` so each retry
attempt sees a flattened payload. See ``chat_deepagent.py``.
Idempotent: a string-content system message is left untouched. A list
that contains anything other than plain text blocks (e.g. an image) is
also left untouched those are rare on system messages and we'd lose
the non-text payload by joining.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import SystemMessage
logger = logging.getLogger(__name__)
def _flatten_text_blocks(content: list[Any]) -> str | None:
"""Return joined text if every block is a plain ``{"type": "text"}``.
Returns ``None`` when the list contains anything that isn't a text
block we can safely concatenate (image, audio, file, non-standard
blocks, dicts with extra non-cache_control fields). The caller
leaves the original content untouched in that case rather than
silently dropping payload.
``cache_control`` on individual blocks is intentionally discarded
the whole point of flattening is to let LiteLLM's
``cache_control_injection_points`` re-place a single breakpoint on
the resulting one-block system content.
"""
chunks: list[str] = []
for block in content:
if isinstance(block, str):
chunks.append(block)
continue
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
chunks.append(text)
return "\n\n".join(chunks)
def _flattened_request(
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
"""Return a request with system_message flattened, or ``None`` for no-op."""
sys_msg = request.system_message
if sys_msg is None:
return None
content = sys_msg.content
if not isinstance(content, list) or len(content) <= 1:
return None
flattened = _flatten_text_blocks(content)
if flattened is None:
return None
new_sys = SystemMessage(
content=flattened,
additional_kwargs=dict(sys_msg.additional_kwargs),
response_metadata=dict(sys_msg.response_metadata),
)
if sys_msg.id is not None:
new_sys.id = sys_msg.id
return request.override(system_message=new_sys)
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
"""One-line dump of cache_control-relevant request shape.
Temporary diagnostic to prove where the ``Found N`` cache_control
breakpoints are coming from when Anthropic 400s. Removed once the
root cause is confirmed and a fix is in place.
"""
sys_msg = request.system_message
if sys_msg is None:
sys_shape = "none"
elif isinstance(sys_msg.content, str):
sys_shape = f"str(len={len(sys_msg.content)})"
elif isinstance(sys_msg.content, list):
sys_shape = f"list(blocks={len(sys_msg.content)})"
else:
sys_shape = f"other({type(sys_msg.content).__name__})"
role_hist: list[str] = []
multi_block_msgs = 0
msgs_with_cc = 0
sys_msgs_in_history = 0
for m in request.messages:
mtype = getattr(m, "type", type(m).__name__)
role_hist.append(mtype)
if isinstance(m, SystemMessage):
sys_msgs_in_history += 1
c = getattr(m, "content", None)
if isinstance(c, list):
multi_block_msgs += 1
for blk in c:
if isinstance(blk, dict) and "cache_control" in blk:
msgs_with_cc += 1
break
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
msgs_with_cc += 1
tools = request.tools or []
tools_with_cc = 0
for t in tools:
if isinstance(t, dict) and (
"cache_control" in t or "cache_control" in t.get("function", {})
):
tools_with_cc += 1
return (
f"sys={sys_shape} msgs={len(request.messages)} "
f"sys_msgs_in_history={sys_msgs_in_history} "
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
f"roles={role_hist[-8:]}"
)
class FlattenSystemMessageMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Collapse a multi-text-block system message to a single string.
Sits innermost on the system-message-mutation chain so it observes
every middleware's contribution. Has no other side effect — the
body of every block is preserved, just joined with ``"\\n\\n"``.
"""
def __init__(self) -> None:
super().__init__()
self.tools = []
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return handler(flattened)
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return await handler(flattened)
return await handler(request)
__all__ = [
"FlattenSystemMessageMiddleware",
"_flatten_text_blocks",
"_flattened_request",
]

View file

@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState, state: AgentState,
runtime: Runtime[Any], runtime: Runtime[Any],
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD: if self.filesystem_mode != FilesystemMode.CLOUD:
return None return None
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc: if anon_doc:
return self._anon_priority(state, anon_doc) return self._anon_priority(state, anon_doc)
return await self._authenticated_priority(state, messages, user_text) return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority( def _anon_priority(
self, self,
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState, state: AgentState,
messages: Sequence[BaseMessage], messages: Sequence[BaseMessage],
user_text: str, user_text: str,
runtime: Runtime[Any] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time() t0 = asyncio.get_event_loop().time()
( (
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text, user_text=user_text,
) )
mentioned_results: list[dict[str, Any]] = [] # Per-turn ``mentioned_document_ids`` flow:
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
#
# CRITICAL: distinguish "context absent" (legacy caller, no field at
# all) from "context provided but empty" (turn with no mentions).
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions)
if self.mentioned_document_ids: if self.mentioned_document_ids:
self.mentioned_document_ids = []
elif self.mentioned_document_ids:
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
mentioned_results: list[dict[str, Any]] = []
if mention_ids:
mentioned_results = await fetch_mentioned_documents( mentioned_results = await fetch_mentioned_documents(
document_ids=self.mentioned_document_ids, document_ids=mention_ids,
search_space_id=self.search_space_id, search_space_id=self.search_space_id,
) )
self.mentioned_document_ids = []
if is_recency: if is_recency:
doc_types = _resolve_search_types( doc_types = _resolve_search_types(

View file

@ -1,4 +1,4 @@
"""LiteLLM-native prompt caching configuration for SurfSense agents. r"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack its ``isinstance(model, ChatAnthropic)`` activated for our LiteLLM-based stack its ``isinstance(model, ChatAnthropic)``
@ -17,8 +17,20 @@ Coverage:
We inject **two** breakpoints per request: We inject **two** breakpoints per request:
- ``role: system`` pins the SurfSense system prompt (provider variant, - ``index: 0`` pins the SurfSense system prompt at the head of the
citation rules, tool catalog, KB tree, skills metadata) into the cache. request (provider variant, citation rules, tool catalog, KB tree,
skills metadata). The langchain agent factory always prepends
``request.system_message`` at index 0 (see ``factory.py``
``_execute_model_async``), so this targets exactly the main system
prompt regardless of how many other ``SystemMessage``\ s the
``before_agent`` injectors (priority, tree, memory, file-intent,
anonymous-doc) have inserted into ``state["messages"]``. Using
``role: system`` here would apply ``cache_control`` to **every**
system-role message and trip Anthropic's hard cap of 4 cache
breakpoints per request once the conversation accumulates enough
injected system messages which surfaces as the upstream 400
``A maximum of 4 blocks with cache_control may be provided. Found N``
via OpenRouterAnthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound: - ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix. N+1 still reads turn N's cache up to the shared prefix.
@ -51,11 +63,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Two-breakpoint policy: system + latest message. See module docstring for # Two-breakpoint policy: head-of-request + latest message. See module
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we # docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# use 2 here, leaving headroom for Phase-2 tool caching. # blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
# anonymous-doc) insert ``SystemMessage`` instances into
# ``state["messages"]`` that accumulate across turns. With
# ``role: system`` the LiteLLM hook would tag *every* one of them with
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
# always targets the langchain-prepended ``request.system_message``
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
# block), giving us exactly one stable cache breakpoint.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "role": "system"}, {"location": "message", "index": 0},
{"location": "message", "index": -1}, {"location": "message", "index": -1},
) )

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the create_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_confluence_page( async def create_confluence_page(
title: str, title: str,
@ -42,13 +60,14 @@ def create_create_confluence_page_tool(
""" """
logger.info(f"create_confluence_page called: title='{title}'") logger.info(f"create_confluence_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session) metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
@ -183,7 +202,9 @@ def create_create_confluence_page_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the delete_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_confluence_page( async def delete_confluence_page(
page_title_or_id: str, page_title_or_id: str,
@ -43,13 +61,14 @@ def create_delete_confluence_page_tool(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session) metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_deletion_context( context = await metadata_service.get_deletion_context(
search_space_id, user_id, page_title_or_id search_space_id, user_id, page_title_or_id
@ -95,7 +114,9 @@ def create_delete_confluence_page_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
from sqlalchemy.future import select from sqlalchemy.future import select
@ -135,7 +156,10 @@ def create_delete_confluence_page_tool(
or "status code 403" in str(api_err).lower() or "status code 403" in str(api_err).lower()
): ):
try: try:
connector.config = {**connector.config, "auth_expired": True} connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config") flag_modified(connector, "config")
await db_session.commit() await db_session.commit()
except Exception: except Exception:

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the update_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_confluence_page( async def update_confluence_page(
page_title_or_id: str, page_title_or_id: str,
@ -45,13 +63,14 @@ def create_update_confluence_page_tool(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session) metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, page_title_or_id search_space_id, user_id, page_title_or_id
@ -152,7 +171,10 @@ def create_update_confluence_page_tool(
or "status code 403" in str(api_err).lower() or "status code 403" in str(api_err).lower()
): ):
try: try:
connector.config = {**connector.config, "auth_expired": True} connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config") flag_modified(connector, "config")
await db_session.commit() await db_session.commit()
except Exception: except Exception:

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
search_space_id: int, search_space_id: int,
user_id: str, user_id: str,
) -> StructuredTool: ) -> StructuredTool:
"""Factory function to create the get_connected_accounts tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to scope account discovery to.
user_id: User ID to scope account discovery to.
Returns:
Configured StructuredTool for connected-accounts discovery.
"""
del db_session # per-call session — see docstring
async def _run(service: str) -> list[dict[str, Any]]: async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service) svc_cfg = MCP_SERVICES.get(service)
@ -68,6 +85,7 @@ def create_get_connected_accounts_tool(
except ValueError: except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_discord_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_discord_channels tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_discord_channels() -> dict[str, Any]: async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server. """List text channels in the connected Discord server.
@ -22,13 +41,14 @@ def create_list_discord_channels_tool(
Returns: Returns:
Dictionary with status and a list of channels (id, name). Dictionary with status and a list of channels (id, name).
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
connector = await get_discord_connector( connector = await get_discord_connector(
db_session, search_space_id, user_id db_session, search_space_id, user_id
) )

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_discord_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_discord_messages tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_discord_messages( async def read_discord_messages(
channel_id: str, channel_id: str,
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
Dictionary with status and a list of messages including Dictionary with status and a list of messages including
id, author, content, timestamp. id, author, content, timestamp.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
@ -39,6 +58,7 @@ def create_read_discord_messages_tool(
limit = min(limit, 50) limit = min(limit, 50)
try: try:
async with async_session_maker() as db_session:
connector = await get_discord_connector( connector = await get_discord_connector(
db_session, search_space_id, user_id db_session, search_space_id, user_id
) )

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector from ._auth import DISCORD_API, get_bot_token, get_discord_connector
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_discord_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_discord_message tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_discord_message( async def send_discord_message(
channel_id: str, channel_id: str,
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
@ -47,6 +65,7 @@ def create_send_discord_message_tool(
} }
try: try:
async with async_session_maker() as db_session:
connector = await get_discord_connector( connector = await get_discord_connector(
db_session, search_space_id, user_id db_session, search_space_id, user_id
) )

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_dropbox_file( async def create_dropbox_file(
name: str, name: str,
@ -82,13 +99,14 @@ def create_create_dropbox_file_tool(
f"create_dropbox_file called: name='{name}', file_type='{file_type}'" f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Dropbox tool not properly configured.", "message": "Dropbox tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
@ -149,7 +167,9 @@ def create_create_dropbox_file_tool(
] ]
except Exception: except Exception:
logger.warning( logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True "Error fetching folders for connector %s",
cid,
exc_info=True,
) )
parent_folders[cid] = [] parent_folders[cid] = []
@ -217,7 +237,9 @@ def create_create_dropbox_file_tool(
) )
if final_file_type == "paper": if final_file_type == "paper":
created = await client.create_paper_doc(file_path, final_content or "") created = await client.create_paper_doc(
file_path, final_content or ""
)
file_id = created.get("file_id", "") file_id = created.get("file_id", "")
web_url = created.get("url", "") web_url = created.get("url", "")
else: else:
@ -246,7 +268,9 @@ def create_create_dropbox_file_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType, DocumentType,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
async_session_maker,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_dropbox_file( async def delete_dropbox_file(
file_name: str, file_name: str,
@ -55,13 +73,14 @@ def create_delete_dropbox_file_tool(
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Dropbox tool not properly configured.", "message": "Dropbox tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
doc_result = await db_session.execute( doc_result = await db_session.execute(
select(Document) select(Document)
.join( .join(
@ -193,14 +212,17 @@ def create_delete_dropbox_file_tool(
final_file_path = result.params.get("file_path", file_path) final_file_path = result.params.get("file_path", file_path)
final_connector_id = result.params.get("connector_id", connector.id) final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id: if final_connector_id != connector.id:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
and_( and_(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR, == SearchSourceConnectorType.DROPBOX_CONNECTOR,
@ -221,7 +243,9 @@ def create_delete_dropbox_file_tool(
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
) )
client = DropboxClient(session=db_session, connector_id=actual_connector_id) client = DropboxClient(
session=db_session, connector_id=actual_connector_id
)
await client.delete_file(final_file_path) await client.delete_file(final_file_path)
logger.info(f"Dropbox file deleted: path={final_file_path}") logger.info(f"Dropbox file deleted: path={final_file_path}")

View file

@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService, ImageGenRouterService,
is_image_gen_auto_mode, is_image_gen_auto_mode,
) )
from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
} }
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string( def _build_model_string(
provider: str, model_name: str, custom_provider: str | None provider: str, model_name: str, custom_provider: str | None
) -> str: ) -> str:
if custom_provider: prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}" return f"{prefix}/{model_name}"
@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found" "error": f"Image generation config {config_id} not found"
} }
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("provider", ""), cfg.get("custom_provider")
cfg["model_name"],
cfg.get("custom_provider"),
) )
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key") gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"): api_base = resolve_api_base(
gen_kwargs["api_base"] = cfg["api_base"] provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"): if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"] gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"): if cfg.get("litellm_params"):
@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found" "error": f"Image generation config {config_id} not found"
} }
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.provider.value, db_cfg.custom_provider
db_cfg.model_name,
db_cfg.custom_provider,
) )
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base: api_base = resolve_api_base(
gen_kwargs["api_base"] = db_cfg.api_base provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version: if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params: if db_cfg.litellm_params:

View file

@ -0,0 +1,41 @@
from typing import Any
from app.db import SearchSourceConnector
from app.services.composio_service import ComposioService
def split_recipients(value: str | None) -> list[str]:
if not value:
return []
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
def unwrap_composio_data(data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def execute_composio_gmail_tool(
connector: SearchSourceConnector,
user_id: str,
tool_name: str,
params: dict[str, Any],
) -> tuple[Any, str | None]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return None, "Composio connected account ID not found for this Gmail connector."
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Gmail error")
return unwrap_composio_data(result.get("data")), None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_gmail_draft( async def create_gmail_draft(
to: str, to: str,
@ -57,20 +75,23 @@ def create_create_gmail_draft_tool(
""" """
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session) metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", []) accounts = context.get("accounts", [])
@ -157,16 +178,13 @@ def create_create_gmail_draft_tool(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
) )
if ( is_composio_gmail = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this Gmail connector.", "message": "Composio connected account ID not found for this Gmail connector.",
@ -186,13 +204,17 @@ def create_create_gmail_draft_tool(
config_data["token"] config_data["token"]
) )
if config_data.get("refresh_token"): if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token( config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"] config_data["refresh_token"]
) )
)
if config_data.get("client_secret"): if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token( config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"] config_data["client_secret"]
) )
)
exp = config_data.get("expiry", "") exp = config_data.get("expiry", "")
if exp: if exp:
@ -208,10 +230,6 @@ def create_create_gmail_draft_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body) message = MIMEText(final_body)
message["to"] = final_to message["to"] = final_to
message["subject"] = final_subject message["subject"] = final_subject
@ -222,6 +240,34 @@ def create_create_gmail_draft_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode() raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try: try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
created, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_CREATE_EMAIL_DRAFT",
{
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(created, dict):
created = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
created = await asyncio.get_event_loop().run_in_executor( created = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (
@ -285,7 +331,9 @@ def create_create_gmail_draft_tool(
draft_id=created.get("id"), draft_id=created.get("id"),
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_gmail_email(message_id: str) -> dict[str, Any]: async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID. """Read the full content of a specific Gmail email by its message ID.
@ -32,10 +49,11 @@ def create_read_gmail_email_tool(
Returns: Returns:
Dictionary with status and the full email content formatted as markdown. Dictionary with status and the full email content formatted as markdown.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."} return {"status": "error", "message": "Gmail tool not properly configured."}
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
@ -50,7 +68,57 @@ def create_read_gmail_email_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
} }
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.agents.new_chat.tools.gmail.search_emails import (
_format_gmail_summary,
)
from app.services.composio_service import ComposioService
service = ComposioService()
detail, error = await service.get_gmail_message_detail(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if error:
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
summary = _format_gmail_summary(detail)
content = (
f"# {summary['subject']}\n\n"
f"**From:** {summary['from']}\n"
f"**To:** {summary['to']}\n"
f"**Date:** {summary['date']}\n\n"
f"## Message Content\n\n"
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {summary['message_id']}\n"
f"- **Thread ID:** {summary['thread_id']}\n"
)
return {
"status": "success",
"message_id": summary["message_id"] or message_id,
"content": content,
}
from app.agents.new_chat.tools.gmail.search_emails import (
_build_credentials,
)
creds = _build_credentials(connector) creds = _build_credentials(connector)
@ -84,7 +152,11 @@ def create_read_gmail_email_tool(
content = gmail.format_message_to_markdown(detail) content = gmail.format_message_to_markdown(detail)
return {"status": "success", "message_id": message_id, "content": content} return {
"status": "success",
"message_id": message_id,
"content": content,
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
from app.utils.google_credentials import build_composio_credentials raise ValueError("Composio connectors must use Composio tool execution.")
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected account ID not found.")
return build_composio_credentials(cca_id)
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
) )
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
headers = message.get("payload", {}).get("headers", [])
return {
header.get("name", "").lower(): header.get("value", "")
for header in headers
if isinstance(header, dict)
}
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
headers = _gmail_headers(message)
return {
"message_id": message.get("id") or message.get("messageId"),
"thread_id": message.get("threadId"),
"subject": message.get("subject") or headers.get("subject", "No Subject"),
"from": message.get("sender") or headers.get("from", "Unknown"),
"to": message.get("to") or headers.get("to", ""),
"date": message.get("messageTimestamp") or headers.get("date", ""),
"snippet": message.get("snippet") or message.get("messageText", "")[:300],
"labels": message.get("labelIds", []),
}
async def _search_composio_gmail(
connector: SearchSourceConnector,
user_id: str,
query: str,
max_results: int,
) -> dict[str, Any]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.services.composio_service import ComposioService
service = ComposioService()
messages, _next_token, _estimate, error = await service.get_gmail_messages(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=max_results,
)
if error:
return {"status": "error", "message": error}
emails = [_format_gmail_summary(message) for message in messages]
return {
"status": "success",
"emails": emails,
"total": len(emails),
"message": "No emails found." if not emails else None,
}
def create_search_gmail_tool( def create_search_gmail_tool(
db_session: AsyncSession | None = None, db_session: AsyncSession | None = None,
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the search_gmail tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_gmail tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def search_gmail( async def search_gmail(
query: str, query: str,
@ -90,12 +159,13 @@ def create_search_gmail_tool(
Dictionary with status and a list of email summaries including Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet. message_id, subject, from, date, snippet.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."} return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20) max_results = min(max_results, 20)
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
@ -110,6 +180,14 @@ def create_search_gmail_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
} }
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
return await _search_composio_gmail(
connector, str(user_id), query, max_results
)
creds = _build_credentials(connector) creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector from app.connectors.google_gmail_connector import GoogleGmailConnector

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_gmail_email( async def send_gmail_email(
to: str, to: str,
@ -58,20 +76,23 @@ def create_send_gmail_email_tool(
""" """
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session) metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", []) accounts = context.get("accounts", [])
@ -158,16 +179,13 @@ def create_send_gmail_email_tool(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
) )
if ( is_composio_gmail = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this Gmail connector.", "message": "Composio connected account ID not found for this Gmail connector.",
@ -187,13 +205,17 @@ def create_send_gmail_email_tool(
config_data["token"] config_data["token"]
) )
if config_data.get("refresh_token"): if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token( config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"] config_data["refresh_token"]
) )
)
if config_data.get("client_secret"): if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token( config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"] config_data["client_secret"]
) )
)
exp = config_data.get("expiry", "") exp = config_data.get("expiry", "")
if exp: if exp:
@ -209,10 +231,6 @@ def create_send_gmail_email_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body) message = MIMEText(final_body)
message["to"] = final_to message["to"] = final_to
message["subject"] = final_subject message["subject"] = final_subject
@ -223,6 +241,34 @@ def create_send_gmail_email_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode() raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try: try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
sent, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_SEND_EMAIL",
{
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(sent, dict):
sent = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
sent = await asyncio.get_event_loop().run_in_executor( sent = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (
@ -286,7 +332,9 @@ def create_send_gmail_email_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the trash_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured trash_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def trash_gmail_email( async def trash_gmail_email(
email_subject_or_id: str, email_subject_or_id: str,
@ -55,13 +73,14 @@ def create_trash_gmail_email_tool(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}" f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session) metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_trash_context( context = await metadata_service.get_trash_context(
search_space_id, user_id, email_subject_or_id search_space_id, user_id, email_subject_or_id
@ -122,7 +141,9 @@ def create_trash_gmail_email_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if not final_connector_id: if not final_connector_id:
return { return {
@ -158,16 +179,13 @@ def create_trash_gmail_email_tool(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}" f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
) )
if ( is_composio_gmail = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this Gmail connector.", "message": "Composio connected account ID not found for this Gmail connector.",
@ -187,13 +205,17 @@ def create_trash_gmail_email_tool(
config_data["token"] config_data["token"]
) )
if config_data.get("refresh_token"): if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token( config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"] config_data["refresh_token"]
) )
)
if config_data.get("client_secret"): if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token( config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"] config_data["client_secret"]
) )
)
exp = config_data.get("expiry", "") exp = config_data.get("expiry", "")
if exp: if exp:
@ -209,11 +231,24 @@ def create_trash_gmail_email_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
_trashed, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_MOVE_TO_TRASH",
{"user_id": "me", "message_id": final_message_id},
)
if error:
raise RuntimeError(error)
else:
from googleapiclient.discovery import build from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds) gmail_service = build("gmail", "v1", credentials=creds)
try:
await asyncio.get_event_loop().run_in_executor( await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the update_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_gmail_draft( async def update_gmail_draft(
draft_subject_or_id: str, draft_subject_or_id: str,
@ -76,13 +94,14 @@ def create_update_gmail_draft_tool(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session) metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, draft_subject_or_id search_space_id, user_id, draft_subject_or_id
@ -188,16 +207,13 @@ def create_update_gmail_draft_tool(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}" f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
) )
if ( is_composio_gmail = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this Gmail connector.", "message": "Composio connected account ID not found for this Gmail connector.",
@ -217,13 +233,17 @@ def create_update_gmail_draft_tool(
config_data["token"] config_data["token"]
) )
if config_data.get("refresh_token"): if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token( config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"] config_data["refresh_token"]
) )
)
if config_data.get("client_secret"): if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token( config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"] config_data["client_secret"]
) )
)
exp = config_data.get("expiry", "") exp = config_data.get("expiry", "")
if exp: if exp:
@ -239,15 +259,19 @@ def create_update_gmail_draft_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
# Resolve draft_id if not already available # Resolve draft_id if not already available
if not final_draft_id: if not final_draft_id:
logger.info( logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
) )
if is_composio_gmail:
final_draft_id = await _find_composio_draft_id_by_message(
connector, user_id, message_id
)
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
final_draft_id = await _find_draft_id_by_message( final_draft_id = await _find_draft_id_by_message(
gmail_service, message_id gmail_service, message_id
) )
@ -272,6 +296,35 @@ def create_update_gmail_draft_tool(
raw = base64.urlsafe_b64encode(message.as_bytes()).decode() raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try: try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
updated, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_UPDATE_DRAFT",
{
"user_id": "me",
"draft_id": final_draft_id,
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(updated, dict):
updated = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
updated = await asyncio.get_event_loop().run_in_executor( updated = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (
@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
except Exception as e: except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}") logger.warning(f"Failed to look up draft by message_id: {e}")
return None return None
async def _find_composio_draft_id_by_message(
connector: Any, user_id: str, message_id: str
) -> str | None:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
page_token = ""
while True:
params: dict[str, Any] = {
"user_id": "me",
"max_results": 100,
"verbose": False,
}
if page_token:
params["page_token"] = page_token
data, error = await execute_composio_gmail_tool(
connector, user_id, "GMAIL_LIST_DRAFTS", params
)
if error or not isinstance(data, dict):
return None
for draft in data.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft.get("id")
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
if not page_token:
return None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_calendar_event( async def create_calendar_event(
summary: str, summary: str,
@ -60,20 +78,23 @@ def create_create_calendar_event_tool(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'" f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session) metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", []) accounts = context.get("accounts", [])
@ -113,7 +134,9 @@ def create_create_calendar_event_tool(
} }
final_summary = result.params.get("summary", summary) final_summary = result.params.get("summary", summary)
final_start_datetime = result.params.get("start_datetime", start_datetime) final_start_datetime = result.params.get(
"start_datetime", start_datetime
)
final_end_datetime = result.params.get("end_datetime", end_datetime) final_end_datetime = result.params.get("end_datetime", end_datetime)
final_description = result.params.get("description", description) final_description = result.params.get("description", description)
final_location = result.params.get("location", location) final_location = result.params.get("location", location)
@ -121,7 +144,10 @@ def create_create_calendar_event_tool(
final_connector_id = result.params.get("connector_id") final_connector_id = result.params.get("connector_id")
if not final_summary or not final_summary.strip(): if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Event summary cannot be empty."} return {
"status": "error",
"message": "Event summary cannot be empty.",
}
from sqlalchemy.future import select from sqlalchemy.future import select
@ -168,16 +194,13 @@ def create_create_calendar_event_tool(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
) )
if ( is_composio_calendar = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Composio connected account ID not found for this connector.",
@ -211,10 +234,6 @@ def create_create_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
tz = context.get("timezone", "UTC") tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = { event_body: dict[str, Any] = {
"summary": final_summary, "summary": final_summary,
@ -231,6 +250,43 @@ def create_create_calendar_event_tool(
] ]
try: try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params = {
"calendar_id": "primary",
"summary": final_summary,
"start_datetime": final_start_datetime,
"end_datetime": final_end_datetime,
"timezone": tz,
"attendees": final_attendees or [],
}
if final_description:
composio_params["description"] = final_description
if final_location:
composio_params["location"] = final_location
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_CREATE_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
created = composio_result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
created = await asyncio.get_event_loop().run_in_executor( created = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (
@ -295,7 +351,9 @@ def create_create_calendar_event_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_calendar_event( async def delete_calendar_event(
event_title_or_id: str, event_title_or_id: str,
@ -54,13 +72,14 @@ def create_delete_calendar_event_tool(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}" f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session) metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_deletion_context( context = await metadata_service.get_deletion_context(
search_space_id, user_id, event_title_or_id search_space_id, user_id, event_title_or_id
@ -121,7 +140,9 @@ def create_delete_calendar_event_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if not final_connector_id: if not final_connector_id:
return { return {
@ -159,16 +180,13 @@ def create_delete_calendar_event_tool(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}" f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
) )
if ( is_composio_calendar = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Composio connected account ID not found for this connector.",
@ -202,11 +220,29 @@ def create_delete_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_DELETE_EVENT",
params={
"calendar_id": "primary",
"event_id": final_event_id,
},
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
else:
service = await asyncio.get_event_loop().run_in_executor( service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds) None, lambda: build("calendar", "v3", credentials=creds)
) )
try:
await asyncio.get_event_loop().run_in_executor( await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (

View file

@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
] ]
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
if "T" in value:
return value
time = "23:59:59" if is_end else "00:00:00"
return f"{value}T{time}Z"
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return events
def create_search_calendar_events_tool( def create_search_calendar_events_tool(
db_session: AsyncSession | None = None, db_session: AsyncSession | None = None,
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the search_calendar_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_calendar_events tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def search_calendar_events( async def search_calendar_events(
start_date: str, start_date: str,
@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
Dictionary with status and a list of events including Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees. event_id, summary, start, end, location, attendees.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Calendar tool not properly configured.", "message": "Calendar tool not properly configured.",
@ -47,6 +93,7 @@ def create_search_calendar_events_tool(
max_results = min(max_results, 50) max_results = min(max_results, 50)
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
@ -61,9 +108,34 @@ def create_search_calendar_events_tool(
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
} }
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
from app.services.composio_service import ComposioService
events_raw, error = await ComposioService().get_calendar_events(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
time_min=_to_calendar_boundary(start_date, is_end=False),
time_max=_to_calendar_boundary(end_date, is_end=True),
max_results=max_results,
)
if not events_raw and not error:
error = "No events found in the specified date range."
else:
creds = _build_credentials(connector) creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import GoogleCalendarConnector from app.connectors.google_calendar_connector import (
GoogleCalendarConnector,
)
cal = GoogleCalendarConnector( cal = GoogleCalendarConnector(
credentials=creds, credentials=creds,
@ -97,24 +169,7 @@ def create_search_calendar_events_tool(
} }
return {"status": "error", "message": error} return {"status": "error", "message": error}
events = [] events = _format_calendar_events(events_raw)
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return {"status": "success", "events": events, "total": len(events)} return {"status": "success", "events": events, "total": len(events)}

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the update_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_calendar_event( async def update_calendar_event(
event_title_or_id: str, event_title_or_id: str,
@ -74,13 +92,14 @@ def create_update_calendar_event_tool(
""" """
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'") logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session) metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, event_title_or_id search_space_id, user_id, event_title_or_id
@ -192,16 +211,13 @@ def create_update_calendar_event_tool(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
) )
if ( is_composio_calendar = (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Composio connected account ID not found for this connector.",
@ -235,10 +251,6 @@ def create_update_calendar_event_tool(
expiry=datetime.fromisoformat(exp) if exp else None, expiry=datetime.fromisoformat(exp) if exp else None,
) )
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
update_body: dict[str, Any] = {} update_body: dict[str, Any] = {}
if final_new_summary is not None: if final_new_summary is not None:
update_body["summary"] = final_new_summary update_body["summary"] = final_new_summary
@ -247,7 +259,9 @@ def create_update_calendar_event_tool(
final_new_start_datetime, context final_new_start_datetime, context
) )
if final_new_end_datetime is not None: if final_new_end_datetime is not None:
update_body["end"] = _build_time_body(final_new_end_datetime, context) update_body["end"] = _build_time_body(
final_new_end_datetime, context
)
if final_new_description is not None: if final_new_description is not None:
update_body["description"] = final_new_description update_body["description"] = final_new_description
if final_new_location is not None: if final_new_location is not None:
@ -264,6 +278,53 @@ def create_update_calendar_event_tool(
} }
try: try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params: dict[str, Any] = {
"calendar_id": "primary",
"event_id": final_event_id,
}
if final_new_summary is not None:
composio_params["summary"] = final_new_summary
if final_new_start_datetime is not None:
composio_params["start_time"] = final_new_start_datetime
if final_new_end_datetime is not None:
composio_params["end_time"] = final_new_end_datetime
if final_new_description is not None:
composio_params["description"] = final_new_description
if final_new_location is not None:
composio_params["location"] = final_new_location
if final_new_attendees is not None:
composio_params["attendees"] = [
e.strip() for e in final_new_attendees if e.strip()
]
if not _is_date_only(
final_new_start_datetime or final_new_end_datetime or ""
):
composio_params["timezone"] = context.get("timezone", "UTC")
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_PATCH_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
updated = composio_result.get("data", {})
if isinstance(updated, dict):
updated = updated.get("data", updated)
if isinstance(updated, dict):
updated = updated.get("response_data", updated)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
updated = await asyncio.get_event_loop().run_in_executor( updated = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: ( lambda: (
@ -314,7 +375,9 @@ def create_update_calendar_event_tool(
kb_message_suffix = "" kb_message_suffix = ""
if document_id is not None: if document_id is not None:
try: try:
from app.services.google_calendar import GoogleCalendarKBSyncService from app.services.google_calendar import (
GoogleCalendarKBSyncService,
)
kb_service = GoogleCalendarKBSyncService(db_session) kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_update( kb_result = await kb_service.sync_after_update(

View file

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured create_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_google_drive_file( async def create_google_drive_file(
name: str, name: str,
@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
f"create_google_drive_file called: name='{name}', type='{file_type}'" f"create_google_drive_file called: name='{name}', type='{file_type}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Drive tool not properly configured. Please contact support.", "message": "Google Drive tool not properly configured. Please contact support.",
@ -78,18 +98,23 @@ def create_create_google_drive_file_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GoogleDriveToolMetadataService(db_session) metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", []) accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts): if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Google Drive accounts have expired authentication") logger.warning(
"All Google Drive accounts have expired authentication"
)
return { return {
"status": "auth_error", "status": "auth_error",
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.", "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
@ -179,23 +204,53 @@ def create_create_google_drive_file_tool(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
) )
pre_built_creds = None is_composio_drive = (
if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
pre_built_creds = build_composio_credentials(cca_id) return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient( client = GoogleDriveClient(
session=db_session, session=db_session,
connector_id=actual_connector_id, connector_id=actual_connector_id,
credentials=pre_built_creds,
) )
try: try:
if is_composio_drive:
from app.services.composio_service import ComposioService
params: dict[str, Any] = {
"name": final_name,
"mimeType": mime_type,
"fields": "id,name,webViewLink,mimeType",
}
if final_parent_folder_id:
params["parents"] = [final_parent_folder_id]
if final_content:
params["description"] = final_content[:4096]
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_CREATE_FILE",
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
created = result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
if not isinstance(created, dict):
created = {}
else:
created = await client.create_file( created = await client.create_file(
name=final_name, name=final_name,
mime_type=mime_type, mime_type=mime_type,
@ -253,7 +308,9 @@ def create_create_google_drive_file_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.client import GoogleDriveClient
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured delete_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_google_drive_file( async def delete_google_drive_file(
file_name: str, file_name: str,
@ -55,13 +75,14 @@ def create_delete_google_drive_file_tool(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Drive tool not properly configured. Please contact support.", "message": "Google Drive tool not properly configured. Please contact support.",
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = GoogleDriveToolMetadataService(db_session) metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_trash_context( context = await metadata_service.get_trash_context(
search_space_id, user_id, file_name search_space_id, user_id, file_name
@ -122,7 +143,9 @@ def create_delete_google_drive_file_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if not final_connector_id: if not final_connector_id:
return { return {
@ -158,23 +181,37 @@ def create_delete_google_drive_file_tool(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
) )
pre_built_creds = None is_composio_drive = (
if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
): )
from app.utils.google_credentials import build_composio_credentials if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
pre_built_creds = build_composio_credentials(cca_id) return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient( client = GoogleDriveClient(
session=db_session, session=db_session,
connector_id=connector.id, connector_id=connector.id,
credentials=pre_built_creds,
) )
try: try:
if is_composio_drive:
from app.services.composio_service import ComposioService
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_TRASH_FILE",
params={"file_id": final_file_id},
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
else:
await client.trash_file(file_id=final_file_id) await client.trash_file(file_id=final_file_id)
except HttpError as http_err: except HttpError as http_err:
if http_err.resp.status == 403: if http_err.resp.status == 403:

View file

@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{ {
"create_gmail_draft", "create_gmail_draft",
"update_gmail_draft", "update_gmail_draft",
"create_calendar_event",
"create_notion_page", "create_notion_page",
"create_confluence_page", "create_confluence_page",
"create_google_drive_file", "create_google_drive_file",

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the create_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Jira API
blocking.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured create_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_jira_issue( async def create_jira_issue(
project_key: str, project_key: str,
@ -49,10 +72,11 @@ def create_create_jira_issue_tool(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'" f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session) metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
@ -97,7 +121,10 @@ def create_create_jira_issue_tool(
final_connector_id = result.params.get("connector_id", connector_id) final_connector_id = result.params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip(): if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Issue summary cannot be empty."} return {
"status": "error",
"message": "Issue summary cannot be empty.",
}
if not final_project_key: if not final_project_key:
return {"status": "error", "message": "A project must be selected."} return {"status": "error", "message": "A project must be selected."}
@ -117,7 +144,10 @@ def create_create_jira_issue_tool(
) )
connector = result.scalars().first() connector = result.scalars().first()
if not connector: if not connector:
return {"status": "error", "message": "No Jira connector found."} return {
"status": "error",
"message": "No Jira connector found.",
}
actual_connector_id = connector.id actual_connector_id = connector.id
else: else:
result = await db_session.execute( result = await db_session.execute(
@ -188,7 +218,9 @@ def create_create_jira_issue_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the delete_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured delete_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_jira_issue( async def delete_jira_issue(
issue_title_or_key: str, issue_title_or_key: str,
@ -44,10 +65,11 @@ def create_delete_jira_issue_tool(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'" f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session) metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_deletion_context( context = await metadata_service.get_deletion_context(
search_space_id, user_id, issue_title_or_key search_space_id, user_id, issue_title_or_key
@ -92,7 +114,9 @@ def create_delete_jira_issue_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
from sqlalchemy.future import select from sqlalchemy.future import select
@ -129,7 +153,10 @@ def create_delete_jira_issue_tool(
except Exception as api_err: except Exception as api_err:
if "status code 403" in str(api_err).lower(): if "status code 403" in str(api_err).lower():
try: try:
connector.config = {**connector.config, "auth_expired": True} connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config") flag_modified(connector, "config")
await db_session.commit() await db_session.commit()
except Exception: except Exception:

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the update_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured update_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_jira_issue( async def update_jira_issue(
issue_title_or_key: str, issue_title_or_key: str,
@ -48,10 +69,11 @@ def create_update_jira_issue_tool(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'" f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session) metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, issue_title_or_key search_space_id, user_id, issue_title_or_key
@ -97,7 +119,9 @@ def create_update_jira_issue_tool(
final_issue_key = result.params.get("issue_key", issue_key) final_issue_key = result.params.get("issue_key", issue_key)
final_summary = result.params.get("new_summary", new_summary) final_summary = result.params.get("new_summary", new_summary)
final_description = result.params.get("new_description", new_description) final_description = result.params.get(
"new_description", new_description
)
final_priority = result.params.get("new_priority", new_priority) final_priority = result.params.get("new_priority", new_priority)
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
@ -140,7 +164,9 @@ def create_update_jira_issue_tool(
"content": [ "content": [
{ {
"type": "paragraph", "type": "paragraph",
"content": [{"type": "text", "text": final_description}], "content": [
{"type": "text", "text": final_description}
],
} }
], ],
} }
@ -161,7 +187,10 @@ def create_update_jira_issue_tool(
except Exception as api_err: except Exception as api_err:
if "status code 403" in str(api_err).lower(): if "status code 403" in str(api_err).lower():
try: try:
connector.config = {**connector.config, "auth_expired": True} connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config") flag_modified(connector, "config")
await db_session.commit() await db_session.commit()
except Exception: except Exception:

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the create_linear_issue tool.
Factory function to create the create_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
Returns: Returns:
Configured create_linear_issue tool Configured create_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def create_linear_issue( async def create_linear_issue(
@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
""" """
logger.info(f"create_linear_issue called: title='{title}'") logger.info(f"create_linear_issue called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -75,13 +83,16 @@ def create_create_linear_issue_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session) metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", []) workspaces = context.get("workspaces", [])
@ -128,7 +139,10 @@ def create_create_linear_issue_tool(
if not final_title or not final_title.strip(): if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace") logger.error("Title is empty or contains only whitespace")
return {"status": "error", "message": "Issue title cannot be empty."} return {
"status": "error",
"message": "Issue title cannot be empty.",
}
if not final_team_id: if not final_team_id:
return { return {
"status": "error", "status": "error",
@ -192,7 +206,9 @@ def create_create_linear_issue_tool(
) )
if result.get("status") == "error": if result.get("status") == "error":
logger.error(f"Failed to create Linear issue: {result.get('message')}") logger.error(
f"Failed to create Linear issue: {result.get('message')}"
)
return {"status": "error", "message": result.get("message")} return {"status": "error", "message": result.get("message")}
logger.info( logger.info(
@ -215,7 +231,9 @@ def create_create_linear_issue_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the delete_linear_issue tool.
Factory function to create the delete_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
Returns: Returns:
Configured delete_linear_issue tool Configured delete_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def delete_linear_issue( async def delete_linear_issue(
@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -83,6 +91,7 @@ def create_delete_linear_issue_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session) metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_delete_context( context = await metadata_service.get_delete_context(
search_space_id, user_id, issue_ref search_space_id, user_id, issue_ref
@ -136,7 +145,9 @@ def create_delete_linear_issue_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info( logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, " f"Deleting Linear issue with final params: issue_id={final_issue_id}, "

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService from app.services.linear import LinearKBSyncService, LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the update_linear_issue tool.
Factory function to create the update_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
Returns: Returns:
Configured update_linear_issue tool Configured update_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def update_linear_issue( async def update_linear_issue(
@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
""" """
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -96,6 +104,7 @@ def create_update_linear_issue_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session) metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, issue_ref search_space_id, user_id, issue_ref

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
@ -17,6 +18,23 @@ def create_create_luma_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_luma_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_luma_event( async def create_luma_event(
name: str, name: str,
@ -40,11 +58,14 @@ def create_create_luma_event_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Luma connector found."} return {"status": "error", "message": "No Luma connector found."}

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_luma_events_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_luma_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_luma_events tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_luma_events( async def list_luma_events(
max_results: int = 25, max_results: int = 25,
@ -28,13 +47,16 @@ def create_list_luma_events_tool(
Dictionary with status and a list of events including Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url. event_id, name, start_at, end_at, location, url.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50) max_results = min(max_results, 50)
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Luma connector found."} return {"status": "error", "message": "No Luma connector found."}

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_luma_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_luma_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_luma_event(event_id: str) -> dict[str, Any]: async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event. """Read detailed information about a specific Luma event.
@ -26,11 +45,14 @@ def create_read_luma_event_tool(
Dictionary with status and full event details including Dictionary with status and full event details including
description, attendees count, meeting URL. description, attendees count, meeting URL.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Luma connector found."} return {"status": "error", "message": "No Luma connector found."}

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,17 @@ def create_create_notion_page_tool(
""" """
Factory function to create the create_notion_page tool. Factory function to create the create_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Notion API
blocking.
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +39,7 @@ def create_create_notion_page_tool(
Returns: Returns:
Configured create_notion_page tool Configured create_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def create_notion_page( async def create_notion_page(
@ -67,7 +78,7 @@ def create_create_notion_page_tool(
""" """
logger.info(f"create_notion_page called: title='{title}'") logger.info(f"create_notion_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -77,13 +88,16 @@ def create_create_notion_page_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = NotionToolMetadataService(db_session) metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_creation_context( context = await metadata_service.get_creation_context(
search_space_id, user_id search_space_id, user_id
) )
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return { return {
"status": "error", "status": "error",
"message": context["error"], "message": context["error"],

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
""" """
Factory function to create the delete_notion_page tool. Factory function to create the delete_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for finding the correct Notion connector user_id: User ID for finding the correct Notion connector
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
Returns: Returns:
Configured delete_notion_page tool Configured delete_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def delete_notion_page( async def delete_notion_page(
@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -73,6 +81,7 @@ def create_delete_notion_page_tool(
} }
try: try:
async with async_session_maker() as db_session:
# Get page context (page_id, account, title) from indexed data # Get page context (page_id, account, title) from indexed data
metadata_service = NotionToolMetadataService(db_session) metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_delete_context( context = await metadata_service.get_delete_context(
@ -136,7 +145,9 @@ def create_delete_notion_page_tool(
final_connector_id = result.params.get( final_connector_id = result.params.get(
"connector_id", connector_id_from_context "connector_id", connector_id_from_context
) )
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info( logger.info(
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_update_notion_page_tool(
""" """
Factory function to create the update_notion_page tool. Factory function to create the update_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache (see
``create_create_notion_page_tool`` for the full rationale).
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_notion_page_tool(
Returns: Returns:
Configured update_notion_page tool Configured update_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def update_notion_page( async def update_notion_page(
@ -71,7 +79,7 @@ def create_update_notion_page_tool(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -88,6 +96,7 @@ def create_update_notion_page_tool(
} }
try: try:
async with async_session_maker() as db_session:
metadata_service = NotionToolMetadataService(db_session) metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_update_context( context = await metadata_service.get_update_context(
search_space_id, user_id, page_title search_space_id, user_id, page_title
@ -204,7 +213,9 @@ def create_update_notion_page_tool(
if result.get("status") == "success" and document_id is not None: if result.get("status") == "success" and document_id is not None:
from app.services.notion import NotionKBSyncService from app.services.notion import NotionKBSyncService
logger.info(f"Updating knowledge base for document {document_id}...") logger.info(
f"Updating knowledge base for document {document_id}..."
)
kb_service = NotionKBSyncService(db_session) kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_update( kb_result = await kb_service.sync_after_update(
document_id=document_id, document_id=document_id,

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_onedrive_file( async def create_onedrive_file(
name: str, name: str,
@ -70,13 +87,14 @@ def create_create_onedrive_file_tool(
""" """
logger.info(f"create_onedrive_file called: name='{name}'") logger.info(f"create_onedrive_file called: name='{name}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "OneDrive tool not properly configured.", "message": "OneDrive tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
@ -136,7 +154,9 @@ def create_create_onedrive_file_tool(
] ]
except Exception: except Exception:
logger.warning( logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True "Error fetching folders for connector %s",
cid,
exc_info=True,
) )
parent_folders[cid] = [] parent_folders[cid] = []
@ -223,7 +243,9 @@ def create_create_onedrive_file_tool(
user_id=user_id, user_id=user_id,
) )
if kb_result["status"] == "success": if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated." kb_message_suffix = (
" Your knowledge base has also been updated."
)
else: else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err: except Exception as kb_err:

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType, DocumentType,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
async_session_maker,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_onedrive_file( async def delete_onedrive_file(
file_name: str, file_name: str,
@ -56,13 +74,14 @@ def create_delete_onedrive_file_tool(
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "OneDrive tool not properly configured.", "message": "OneDrive tool not properly configured.",
} }
try: try:
async with async_session_maker() as db_session:
doc_result = await db_session.execute( doc_result = await db_session.execute(
select(Document) select(Document)
.join( .join(
@ -95,7 +114,9 @@ def create_delete_onedrive_file_tool(
Document.document_type == DocumentType.ONEDRIVE_FILE, Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower( func.lower(
cast( cast(
Document.document_metadata["onedrive_file_name"], Document.document_metadata[
"onedrive_file_name"
],
String, String,
) )
) )
@ -193,14 +214,17 @@ def create_delete_onedrive_file_tool(
final_file_id = result.params.get("file_id", file_id) final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get("connector_id", connector.id) final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id: if final_connector_id != connector.id:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
and_( and_(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR, == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,

View file

@ -824,13 +824,22 @@ async def build_tools_async(
"""Async version of build_tools that also loads MCP tools from database. """Async version of build_tools that also loads MCP tools from database.
Design Note: Design Note:
This function exists because MCP tools require database queries to load user configs, This function exists because MCP tools require database queries to load
while built-in tools are created synchronously from static code. user configs, while built-in tools are created synchronously from static
code.
Alternative: We could make build_tools() itself async and always query the database, Alternative: We could make build_tools() itself async and always query
but that would force async everywhere even when only using built-in tools. The current the database, but that would force async everywhere even when only using
design keeps the simple case (static tools only) synchronous while supporting dynamic built-in tools. The current design keeps the simple case (static tools
database-loaded tools through this async wrapper. only) synchronous while supporting dynamic database-loaded tools through
this async wrapper.
Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
the event loop) are kicked off concurrently. Cold-path savings are
bounded by the slower of the two typically MCP at ~200ms-1.7s
so the parallelization recovers the ~50-200ms previously spent
serially on built-in construction.
Args: Args:
dependencies: Dict containing all possible dependencies dependencies: Dict containing all possible dependencies
@ -843,33 +852,70 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools. List of configured tool instances ready for the agent, including MCP tools.
""" """
import asyncio
import time import time
_perf_log = logging.getLogger("surfsense.perf") _perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG) _perf_log.setLevel(logging.DEBUG)
can_load_mcp = (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
)
# Built-in tool construction is synchronous + CPU-only. Off-loop it so
# MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
# function over its inputs — safe to thread-shift.
_t0 = time.perf_counter() _t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) builtin_task = asyncio.create_task(
asyncio.to_thread(
build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
)
)
mcp_task: asyncio.Task | None = None
if can_load_mcp:
mcp_task = asyncio.create_task(
load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
)
# Surface failures from each task independently so a flaky MCP
# endpoint never poisons built-in tool registration. ``return_exceptions``
# gives us per-task exceptions instead of dropping the second result
# when the first raises.
if mcp_task is not None:
builtin_result, mcp_result = await asyncio.gather(
builtin_task, mcp_task, return_exceptions=True
)
else:
builtin_result = await builtin_task
mcp_result = None
if isinstance(builtin_result, BaseException):
raise builtin_result # built-in registration failure is non-recoverable
tools: list[BaseTool] = builtin_result
_perf_log.info( _perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)", "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
len(tools), len(tools),
) )
# Load MCP tools if requested and dependencies are available if mcp_task is not None:
if ( if isinstance(mcp_result, BaseException):
include_mcp_tools # ``return_exceptions=True`` captures the exception out-of-band,
and "db_session" in dependencies # so ``sys.exc_info()`` is empty here. Pass the captured
and "search_space_id" in dependencies # exception via ``exc_info=`` to get a real traceback.
): logging.error(
try: "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
) )
else:
mcp_tools = mcp_result or []
_perf_log.info( _perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)", "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
len(mcp_tools), len(mcp_tools),
) )
@ -879,8 +925,6 @@ async def build_tools_async(
len(mcp_tools), len(mcp_tools),
[t.name for t in mcp_tools], [t.name for t in mcp_tools],
) )
except Exception as e:
logging.exception("Failed to load MCP tools: %s", e)
logging.info( logging.info(
"Total tools for agent: %d%s", "Total tools for agent: %d%s",

View file

@ -15,7 +15,7 @@ from langchain_core.tools import tool
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text from app.utils.document_converters import embed_text
@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
""" """
Factory function to create the search_surfsense_docs tool. Factory function to create the search_surfsense_docs tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for executing queries db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns: Returns:
A configured tool function for searching Surfsense documentation A configured tool function for searching Surfsense documentation
""" """
del db_session # per-call session — see docstring
@tool @tool
async def search_surfsense_docs(query: str, top_k: int = 10) -> str: async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@ -155,6 +162,7 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
Returns: Returns:
Relevant documentation content formatted with chunk IDs for citations Relevant documentation content formatted with chunk IDs for citations
""" """
async with async_session_maker() as db_session:
return await search_surfsense_docs_async( return await search_surfsense_docs_async(
query=query, query=query,
db_session=db_session, db_session=db_session,

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_teams_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_teams_channels tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_teams_channels() -> dict[str, Any]: async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to. """List all Microsoft Teams and their channels the user has access to.
@ -23,11 +42,14 @@ def create_list_teams_channels_tool(
Dictionary with status and a list of teams, each containing Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name). team_id, team_name, and a list of channels (id, name).
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Teams connector found."} return {"status": "error", "message": "No Teams connector found."}

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_teams_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_teams_messages tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_teams_messages( async def read_teams_messages(
team_id: str, team_id: str,
@ -32,13 +51,16 @@ def create_read_teams_messages_tool(
Dictionary with status and a list of messages including Dictionary with status and a list of messages including
id, sender, content, timestamp. id, sender, content, timestamp.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50) limit = min(limit, 50)
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Teams connector found."} return {"status": "error", "message": "No Teams connector found."}

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
@ -17,6 +18,23 @@ def create_send_teams_message_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_teams_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_teams_message tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_teams_message( async def send_teams_message(
team_id: str, team_id: str,
@ -39,11 +57,14 @@ def create_send_teams_message_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector: if not connector:
return {"status": "error", "message": "No Teams connector found."} return {"status": "error", "message": "No Teams connector found."}

View file

@ -26,7 +26,7 @@ from langchain_core.tools import tool
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User from app.db import SearchSpace, User, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -295,6 +295,25 @@ def create_update_memory_tool(
db_session: AsyncSession, db_session: AsyncSession,
llm: Any | None = None, llm: Any | None = None,
): ):
"""Factory function to create the user-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
user_id: ID of the user whose memory document is being updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the user-memory scope.
"""
del db_session # per-call session — see docstring
uid = UUID(user_id) if isinstance(user_id, str) else user_id uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool @tool
@ -311,6 +330,7 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff). updated_memory: The FULL updated markdown document (not a diff).
""" """
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute(select(User).where(User.id == uid)) result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first() user = result.scalars().first()
if not user: if not user:
@ -330,7 +350,6 @@ def create_update_memory_tool(
) )
except Exception as e: except Exception as e:
logger.exception("Failed to update user memory: %s", e) logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"message": f"Failed to update memory: {e}", "message": f"Failed to update memory: {e}",
@ -344,6 +363,27 @@ def create_update_team_memory_tool(
db_session: AsyncSession, db_session: AsyncSession,
llm: Any | None = None, llm: Any | None = None,
): ):
"""Factory function to create the team-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
search_space_id: ID of the search space whose team memory is being
updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the team-memory scope.
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_memory(updated_memory: str) -> dict[str, Any]: async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space. """Update the team's shared memory document for this search space.
@ -359,6 +399,7 @@ def create_update_team_memory_tool(
updated_memory: The FULL updated markdown document (not a diff). updated_memory: The FULL updated markdown document (not a diff).
""" """
try: try:
async with async_session_maker() as db_session:
result = await db_session.execute( result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id) select(SearchSpace).where(SearchSpace.id == search_space_id)
) )
@ -372,7 +413,9 @@ def create_update_team_memory_tool(
updated_memory=updated_memory, updated_memory=updated_memory,
old_memory=old_memory, old_memory=old_memory,
llm=llm, llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content), apply_fn=lambda content: setattr(
space, "shared_memory_md", content
),
commit_fn=db_session.commit, commit_fn=db_session.commit,
rollback_fn=db_session.rollback, rollback_fn=db_session.rollback,
label="team memory", label="team memory",
@ -380,7 +423,6 @@ def create_update_team_memory_tool(
) )
except Exception as e: except Exception as e:
logger.exception("Failed to update team memory: %s", e) logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"message": f"Failed to update team memory: {e}", "message": f"Failed to update team memory: {e}",

View file

@ -421,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
OpenRouterIntegrationService.get_instance().stop_background_refresh() OpenRouterIntegrationService.get_instance().stop_background_refresh()
async def _warm_agent_jit_caches() -> None:
"""Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
Why
----
A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
generation chain takes 1.5-2 seconds of pure CPU on first invocation
inside any Python process: the graph compiler builds reducers,
Pydantic v2 generates and JITs validator schemas, deepagents
eagerly compiles its general-purpose subagent, etc. Subsequent
compiles in the same process pay only ~50% of that cost (the lazy
JIT bits are cached in module-level dicts).
Doing one throwaway compile during ``lifespan`` startup pre-pays
that cost so the *first real request* doesn't. We do NOT prime
:mod:`agent_cache` because the cache key requires real
``thread_id`` / ``user_id`` / ``search_space_id`` / etc. the
throwaway agent is genuinely thrown away and immediately collected.
Safety
------
* No DB access. We construct a stub LLM (no real keys), pass an
empty tools list, and pass ``checkpointer=None`` so we never
touch Postgres.
* Bounded by ``asyncio.wait_for`` so a hang here can never block
worker startup. On any failure, we log + swallow the worst
case is the first real request pays the full cold cost (i.e.
pre-warmup behaviour).
"""
import time as _time
logger = logging.getLogger(__name__)
t0 = _time.perf_counter()
try:
from langchain.agents import create_agent
from langchain.agents.middleware import (
ModelCallLimitMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
)
from langchain_core.tools import tool
from app.agents.new_chat.context import SurfSenseContextSchema
# Minimal LLM stub. ``FakeListChatModel`` satisfies
# ``BaseChatModel`` without any network or auth — perfect for
# exercising the compile path without side effects.
stub_llm = FakeListChatModel(responses=["warmup-response"])
# Two trivial tools with arg + return schemas — exercises the
# Pydantic v2 schema JIT path. Without at least one tool the
# graph compile skips the tool-loop bytecode generation that
# accounts for ~30-50% of cold compile cost.
@tool
def _warmup_tool_a(query: str, limit: int = 5) -> str:
"""Warmup tool A — never actually invoked."""
return query[:limit]
@tool
def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
"""Warmup tool B — never actually invoked."""
return {"name": name, "value": value}
# A handful of common middleware so the compile pre-pays the
# ``AgentMiddleware`` resolver path. These instances never run
# because the throwaway agent is immediately collected.
# ``SubAgentMiddleware`` is the single heaviest line in cold
# ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
# compile its general-purpose subagent's full inner graph),
# so we include it here to make sure that compile path is JIT'd.
warmup_middleware: list = [
TodoListMiddleware(),
ModelCallLimitMiddleware(
thread_limit=120, run_limit=80, exit_behavior="end"
),
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
),
]
try:
from deepagents import SubAgentMiddleware
from deepagents.backends import StateBackend
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
gp_warmup_spec = { # type: ignore[var-annotated]
**GENERAL_PURPOSE_SUBAGENT,
"model": stub_llm,
"tools": [_warmup_tool_a],
"middleware": [TodoListMiddleware()],
}
warmup_middleware.append(
SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
)
except Exception:
# Deepagents missing/incompatible — middleware-only warmup
# still produces a useful (smaller) speedup.
logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
compiled = create_agent(
stub_llm,
tools=[_warmup_tool_a, _warmup_tool_b],
system_prompt="You are a warmup stub.",
middleware=warmup_middleware,
context_schema=SurfSenseContextSchema,
checkpointer=None,
)
# Touch the compiled graph's stream_channels / nodes so any
# remaining lazy schema work fires now instead of on first
# real invocation.
_ = list(getattr(compiled, "nodes", {}).keys())
del compiled
logger.info(
"[startup] Agent JIT warmup completed in %.3fs",
_time.perf_counter() - t0,
)
except Exception:
logger.warning(
"[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
"real request will pay the full compile cost)",
_time.perf_counter() - t0,
exc_info=True,
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected # Tune GC: lower gen-2 threshold so long-lived garbage is collected
@ -445,6 +574,18 @@ async def lifespan(app: FastAPI):
"Docs will be indexed on the next restart." "Docs will be indexed on the next restart."
) )
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
# worker readiness. ``shield`` so Uvicorn cancelling startup
# doesn't leave half-warmed Pydantic schemas in an inconsistent
# state.
try:
await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
except (TimeoutError, Exception): # pragma: no cover - defensive
logging.getLogger(__name__).warning(
"[startup] Agent JIT warmup hit timeout/error — skipping; "
"first real request will pay the full compile cost."
)
log_system_snapshot("startup_complete") log_system_snapshot("startup_complete")
yield yield

View file

@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f) data = yaml.safe_load(f)
configs = data.get("global_llm_configs", []) configs = data.get("global_llm_configs", [])
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {} seen_slugs: dict[str, int] = {}
for cfg in configs: for cfg in configs:
cfg.setdefault("billing_tier", "free") cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False) cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False) cfg.setdefault("seo_enabled", False)
# Capability flag: explicit YAML override always wins. When the
# operator has not annotated the model, defer to LiteLLM's
# authoritative model map (`supports_vision`) which already
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
# vision-capable. Unknown / unmapped models default-allow so
# we don't lock the user out of a freshly added third-party
# entry; the streaming-task safety net (driven by
# `is_known_text_only_chat_model`) is the only place a False
# actually blocks a request.
if "supports_image_input" not in cfg:
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
if cfg.get("seo_enabled") and cfg.get("seo_slug"): if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"] slug = cfg["seo_slug"]

View file

@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel from pydantic import BaseModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.config import config
from app.db import User from app.db import User
from app.users import current_active_user from app.users import current_active_user
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool enable_otel: bool
enable_desktop_local_filesystem: bool
@classmethod @classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows. # asdict() avoids missing-field bugs when AgentFeatureFlags grows.
return cls(**asdict(flags)) return cls(
**asdict(flags),
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
)
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead) @router.get("/agent/flags", response_model=AgentFeatureFlagsRead)

View file

@ -649,13 +649,9 @@ async def list_composio_drive_folders(
""" """
List folders AND files in user's Google Drive via Composio. List folders AND files in user's Google Drive via Composio.
Uses the same GoogleDriveClient / list_folder_contents path as the native Uses Composio's Google Drive tool execution path so managed OAuth tokens
connector, with Composio-sourced credentials. This means auth errors do not need to be exposed through connected account state.
propagate identically (Google returns 401 exception auth_expired flag).
""" """
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
from app.utils.google_credentials import build_composio_credentials
if not ComposioService.is_enabled(): if not ComposioService.is_enabled():
raise HTTPException( raise HTTPException(
status_code=503, status_code=503,
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.", detail="Composio connected account not found. Please reconnect the connector.",
) )
credentials = build_composio_credentials(composio_connected_account_id) service = ComposioService()
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials) entity_id = f"surfsense_{user.id}"
items = []
page_token = None
error = None
items, error = await list_folder_contents(drive_client, parent_id=parent_id) while True:
page_items, next_token, page_error = await service.get_drive_files(
connected_account_id=composio_connected_account_id,
entity_id=entity_id,
folder_id=parent_id,
page_token=page_token,
page_size=100,
)
if page_error:
error = page_error
break
items.extend(page_items)
if not next_token:
break
page_token = next_token
for item in items:
item["isFolder"] = (
item.get("mimeType") == "application/vnd.google-apps.folder"
)
items.sort(
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
)
if error: if error:
error_lower = error.lower() error_lower = error.lower()

View file

@ -46,6 +46,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService, ImageGenRouterService,
is_image_gen_auto_mode, is_image_gen_auto_mode,
) )
from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user from app.users import current_active_user
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token from app.utils.signed_image_urls import verify_image_token
@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string( def _build_model_string(
provider: str, model_name: str, custom_provider: str | None provider: str, model_name: str, custom_provider: str | None
) -> str: ) -> str:
"""Build a litellm model string from provider + model_name.""" """Build a litellm model string from provider + model_name."""
if custom_provider: return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
async def _resolve_billing_for_image_gen( async def _resolve_billing_for_image_gen(
@ -187,12 +192,18 @@ async def _execute_image_generation(
if not cfg: if not cfg:
raise ValueError(f"Global image generation config {config_id} not found") raise ValueError(f"Global image generation config {config_id} not found")
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") cfg.get("provider", ""), cfg.get("custom_provider")
) )
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key") gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"): api_base = resolve_api_base(
gen_kwargs["api_base"] = cfg["api_base"] provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"): if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"] gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"): if cfg.get("litellm_params"):
@ -214,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg: if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found") raise ValueError(f"Image generation config {config_id} not found")
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider db_cfg.provider.value, db_cfg.custom_provider
) )
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base: api_base = resolve_api_base(
gen_kwargs["api_base"] = db_cfg.api_base provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version: if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params: if db_cfg.litellm_params:
@ -277,10 +294,12 @@ async def get_global_image_gen_configs(
# Auto mode currently treated as free until per-deployment # Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen). # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free", "billing_tier": "free",
"is_premium": False,
} }
) )
for cfg in global_configs: for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append( safe_configs.append(
{ {
"id": cfg.get("id"), "id": cfg.get("id"),
@ -293,7 +312,11 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None, "api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}), "litellm_params": cfg.get("litellm_params", {}),
"is_global": True, "is_global": True,
"billing_tier": cfg.get("billing_tier", "free"), "billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"), "quota_reserve_micros": cfg.get("quota_reserve_micros"),
} }
) )

View file

@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate, NewLLMConfigUpdate,
) )
from app.services.llm_service import validate_llm_config from app.services.llm_service import validate_llm_config
from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user from app.users import current_active_user
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
There is no DB column for ``supports_image_input`` the value is
resolved at the API boundary from LiteLLM's authoritative model map
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
the response shape consistent across list / detail / create / update
endpoints without having to remember to set the field at every call
site.
"""
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
)
# ``model_validate`` runs the Pydantic conversion using the ORM
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
# then we layer the derived field on. ``model_copy(update=...)`` keeps
# the surface immutable from the caller's perspective.
base_read = NewLLMConfigRead.model_validate(config)
return base_read.model_copy(update={"supports_image_input": supports_image_input})
# ============================================================================= # =============================================================================
# Global Configs Routes # Global Configs Routes
# ============================================================================= # =============================================================================
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None, "seo_title": None,
"seo_description": None, "seo_description": None,
"quota_reserve_tokens": None, "quota_reserve_tokens": None,
# Auto routes across the configured pool, which usually
# includes at least one vision-capable deployment, so
# treat Auto as image-capable. The router itself will
# still pick a vision-capable deployment for messages
# carrying image_url blocks (LiteLLM Router falls back
# on ``404`` per its ``allowed_fails`` policy).
"supports_image_input": True,
} }
) )
# Add individual global configs # Add individual global configs
for cfg in global_configs: for cfg in global_configs:
# Capability resolution: explicit value (YAML override or OR
# `_supports_image_input(model)` payload baked in by the
# OpenRouter integration service) wins. Fall back to the
# LiteLLM-driven helper which default-allows on unknown so
# we don't hide vision-capable models that happen to lack a
# YAML annotation. The streaming task safety net is the
# only place a False ever blocks.
if "supports_image_input" in cfg:
supports_image_input = bool(cfg.get("supports_image_input"))
else:
cfg_litellm_params = cfg.get("litellm_params") or {}
cfg_base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
)
safe_config = { safe_config = {
"id": cfg.get("id"), "id": cfg.get("id"),
"name": cfg.get("name"), "name": cfg.get("name"),
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"), "seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"), "seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"supports_image_input": supports_image_input,
} }
safe_configs.append(safe_config) safe_configs.append(safe_config)
@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit() await session.commit()
await session.refresh(db_config) await session.refresh(db_config)
return db_config return _serialize_byok_config(db_config)
except HTTPException: except HTTPException:
raise raise
@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit) .limit(limit)
) )
return result.scalars().all() return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException: except HTTPException:
raise raise
@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space", "You don't have permission to view LLM configurations in this search space",
) )
return config return _serialize_byok_config(config)
except HTTPException: except HTTPException:
raise raise
@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit() await session.commit()
await session.refresh(config) await session.refresh(config)
return config return _serialize_byok_config(config)
except HTTPException: except HTTPException:
raise raise

View file

@ -85,10 +85,12 @@ async def get_global_vision_llm_configs(
# Auto mode treated as free until per-deployment billing-tier # Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity. # surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free", "billing_tier": "free",
"is_premium": False,
} }
) )
for cfg in global_configs: for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append( safe_configs.append(
{ {
"id": cfg.get("id"), "id": cfg.get("id"),
@ -101,7 +103,11 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None, "api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}), "litellm_params": cfg.get("litellm_params", {}),
"is_global": True, "is_global": True,
"billing_tier": cfg.get("billing_tier", "free"), "billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"), "input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"), "output_cost_per_token": cfg.get("output_cost_per_token"),

View file

@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel):
default="free", default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
) )
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_micros: int | None = Field( quota_reserve_micros: int | None = Field(
default=None, default=None,
description=( description=(

View file

@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
created_at: datetime created_at: datetime
search_space_id: int search_space_id: int
user_id: uuid.UUID user_id: uuid.UUID
# Capability flag derived at the API boundary (no DB column). Default
# True matches the conservative-allow stance — a BYOK row that the
# route forgot to augment is not pre-judged. The streaming-task
# safety net is the only place a False actually blocks a request.
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map "
"(``litellm.supports_vision``) — there is no DB column. "
"Default True is the conservative-allow stance for unknown / "
"unmapped models."
),
)
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
created_at: datetime created_at: datetime
search_space_id: int search_space_id: int
user_id: uuid.UUID user_id: uuid.UUID
# Capability flag derived at the API boundary (see NewLLMConfigRead).
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map. "
"Default True is the conservative-allow stance."
),
)
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
seo_title: str | None = None seo_title: str | None = None
seo_description: str | None = None seo_description: str | None = None
quota_reserve_tokens: int | None = None quota_reserve_tokens: int | None = None
supports_image_input: bool = Field(
default=True,
description=(
"Whether the model accepts image inputs (multimodal vision). "
"Derived server-side: OpenRouter dynamic configs use "
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
"authoritative model map (``litellm.supports_vision``). The "
"new-chat selector hints with a 'No image' badge when this is "
"False and there are pending image attachments. The streaming "
"task fails fast only when LiteLLM *explicitly* marks a model "
"as text-only — unknown / unmapped models default-allow."
),
)
# ============================================================================= # =============================================================================

View file

@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel):
default="free", default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
) )
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_tokens: int | None = Field( quota_reserve_tokens: int | None = Field(
default=None, default=None,
description=( description=(

View file

@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
_healthy_until.pop(int(config_id), None) _healthy_until.pop(int(config_id), None)
def _global_candidates() -> list[dict]: def _cfg_supports_image_input(cfg: dict) -> bool:
"""True if the global cfg can accept image inputs.
Prefers the explicit ``supports_image_input`` flag (set by the YAML
loader / OpenRouter integration). Falls back to a LiteLLM lookup so
a YAML entry whose flag was somehow stripped doesn't get wrongly
excluded. Default-allows on unknown the streaming-task safety net
is the actual block, not this filter.
"""
if "supports_image_input" in cfg:
return bool(cfg.get("supports_image_input"))
# Lazy import: provider_capabilities -> llm_config -> services chain;
# importing at module load would create an init-order cycle through
# ``app.config``.
from app.services.provider_capabilities import derive_supports_image_input
cfg_litellm_params = cfg.get("litellm_params") or {}
base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
return derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs. """Return Auto-eligible global cfgs.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
can't be picked as the thread's pin. Also excludes configs currently can't be picked as the thread's pin. Also excludes configs currently
in runtime cooldown (e.g. temporary 429 bursts). in runtime cooldown (e.g. temporary 429 bursts).
When ``requires_image_input`` is True (image turn), additionally
filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request.
""" """
candidates = [ candidates = [
cfg cfg
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
if _is_usable_global_config(cfg) if _is_usable_global_config(cfg)
and not cfg.get("health_gated") and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0))) and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
and (not requires_image_input or _cfg_supports_image_input(cfg))
] ]
return sorted(candidates, key=lambda c: int(c.get("id", 0))) return sorted(candidates, key=lambda c: int(c.get("id", 0)))
@ -185,6 +220,15 @@ def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower() return str(cfg.get("billing_tier", "free")).lower()
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model."""
return (
_tier_of(cfg) == "premium"
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread. """Pick a config with quality-first ranking + deterministic spread.
@ -237,11 +281,20 @@ async def resolve_or_get_pinned_llm_config_id(
selected_llm_config_id: int, selected_llm_config_id: int,
force_repin_free: bool = False, force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None, exclude_config_ids: set[int] | None = None,
requires_image_input: bool = False,
) -> AutoPinResolution: ) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin. """Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns For non-auto selections, this function clears any existing pin and returns
the selected id as-is. the selected id as-is.
When ``requires_image_input`` is True (the current turn carries an
``image_url`` block), the candidate pool is filtered to vision-capable
cfgs and any existing pin that can't accept image input is treated as
invalid (force re-pin). If no vision-capable cfg is available the
function raises ``ValueError`` so the streaming task surfaces the same
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
silently routing the image to a text-only deployment.
""" """
thread = ( thread = (
( (
@ -274,14 +327,24 @@ async def resolve_or_get_pinned_llm_config_id(
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [ candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids c
for c in _global_candidates(requires_image_input=requires_image_input)
if int(c.get("id", 0)) not in excluded_ids
] ]
if not candidates: if not candidates:
if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError(
"No vision-capable global LLM configs are available for Auto mode"
)
raise ValueError("No usable global LLM configs are available for Auto mode") raise ValueError("No usable global LLM configs are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates} candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent # Reuse an existing valid pin without re-checking current quota (no silent
# tier switch), unless the caller explicitly requests a forced repin to free. # tier switch), unless the caller explicitly requests a forced repin to free
# *or* the turn requires image input but the pin can't handle it.
pinned_id = thread.pinned_llm_config_id pinned_id = thread.pinned_llm_config_id
if ( if (
not force_repin_free not force_repin_free
@ -311,6 +374,29 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=True, from_existing_pin=True,
) )
if pinned_id is not None: if pinned_id is not None:
# If the pin is *only* invalid because it can't handle the image
# turn (it's still a healthy, usable config in the broader pool),
# log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown.
if requires_image_input:
try:
pinned_global = next(
c
for c in config.GLOBAL_LLM_CONFIGS
if int(c.get("id", 0)) == int(pinned_id)
)
except StopIteration:
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info( logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id, thread_id,
@ -322,11 +408,19 @@ async def resolve_or_get_pinned_llm_config_id(
False if force_repin_free else await _is_premium_eligible(session, user_id) False if force_repin_free else await _is_premium_eligible(session, user_id)
) )
if premium_eligible: if premium_eligible:
eligible = candidates premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
preferred_premium = [
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
]
eligible = preferred_premium or premium_candidates
else: else:
eligible = [c for c in candidates if _tier_of(c) != "premium"] eligible = [c for c in candidates if _tier_of(c) != "premium"]
if not eligible: if not eligible:
if requires_image_input:
raise ValueError(
"Auto mode could not find a vision-capable LLM config for this user and quota state"
)
raise ValueError( raise ValueError(
"Auto mode could not find an eligible LLM config for this user and quota state" "Auto mode could not find an eligible LLM config for this user and quota state"
) )

View file

@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B): KEY DESIGN POINTS (issue A, B):
1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` 1. **Session isolation.** ``billable_call`` takes no caller transaction.
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
insert each run inside their own ``shielded_async_session()``. This inside their own session context. Route callers use
guarantees that a quota commit/rollback can never accidentally flush or ``shielded_async_session()`` by default; Celery callers can provide a
roll back rows the caller has staged in the request's main session worker-loop-safe session factory. This guarantees that quota
(e.g. a freshly-created ``ImageGeneration`` row). commit/rollback can never accidentally flush or roll back rows the caller
has staged in its main session (e.g. a freshly-created
``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via 2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B):
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from typing import Any from typing import Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@ -58,6 +61,12 @@ from app.services.token_tracking_service import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AUDIT_TIMEOUT_SECONDS = 10.0
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
{"video_presentation_generation", "podcast_generation"}
)
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
class QuotaInsufficientError(Exception): class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception):
) )
class BillingSettlementError(Exception):
"""Raised when a premium call completed but credit settlement failed."""
def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
self.usage_type = usage_type
self.user_id = user_id
super().__init__(
f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
)
async def _rollback_safely(session: AsyncSession) -> None:
rollback = getattr(session, "rollback", None)
if rollback is not None:
with suppress(Exception):
await rollback()
async def _record_audit_best_effort(
*,
session_factory: BillableSessionFactory,
usage_type: str,
search_space_id: int,
user_id: UUID,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
cost_micros: int,
model_breakdown: dict[str, Any],
call_details: dict[str, Any] | None,
thread_id: int | None,
message_id: int | None,
audit_label: str,
timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> None:
"""Persist a TokenUsage row without letting audit failure block callers.
Premium settlement is mandatory, but TokenUsage is an audit trail. If the
audit insert or commit hangs, user-facing artifacts such as videos and
podcasts must still be able to transition to READY after settlement.
"""
audit_thread_id = (
None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
)
async def _persist() -> None:
logger.info(
"[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
"total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
total_tokens,
cost_micros,
)
async with session_factory() as audit_session:
try:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=audit_thread_id,
message_id=message_id,
)
logger.info(
"[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
audit_label,
usage_type,
user_id,
audit_thread_id,
)
await audit_session.commit()
logger.info(
"[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
audit_label,
usage_type,
user_id,
audit_thread_id,
)
except BaseException:
await _rollback_safely(audit_session)
raise
try:
await asyncio.wait_for(_persist(), timeout=timeout_seconds)
except TimeoutError:
logger.warning(
"[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
"timeout=%.1fs total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
timeout_seconds,
total_tokens,
cost_micros,
)
except Exception:
logger.exception(
"[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
"total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
total_tokens,
cost_micros,
)
@asynccontextmanager @asynccontextmanager
async def billable_call( async def billable_call(
*, *,
@ -101,6 +228,8 @@ async def billable_call(
thread_id: int | None = None, thread_id: int | None = None,
message_id: int | None = None, message_id: int | None = None,
call_details: dict[str, Any] | None = None, call_details: dict[str, Any] | None = None,
billable_session_factory: BillableSessionFactory | None = None,
audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]: ) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call. """Wrap a single billable LLM/image call.
@ -124,6 +253,13 @@ async def billable_call(
thread_id, message_id: Optional FK columns on ``TokenUsage``. thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters) call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``. forwarded to ``record_token_usage``.
billable_session_factory: Optional async context factory used for
reserve/finalize/release/audit sessions. Defaults to
``shielded_async_session`` for route callers; Celery callers pass
a worker-loop-safe session factory.
audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
Audit failure is best-effort and does not undo successful
settlement.
Yields: Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
@ -134,6 +270,7 @@ async def billable_call(
QuotaInsufficientError: when premium and ``premium_reserve`` denies. QuotaInsufficientError: when premium and ``premium_reserve`` denies.
""" """
is_premium = billing_tier == "premium" is_premium = billing_tier == "premium"
session_factory = billable_session_factory or shielded_async_session
async with scoped_turn() as acc: async with scoped_turn() as acc:
# ---------- Free path: just audit ------------------------------- # ---------- Free path: just audit -------------------------------
@ -143,10 +280,8 @@ async def billable_call(
finally: finally:
# Always audit, even on exception, so we capture cost when # Always audit, even on exception, so we capture cost when
# provider returns successfully but the caller raises later. # provider returns successfully but the caller raises later.
try: await _record_audit_best_effort(
async with shielded_async_session() as audit_session: session_factory=session_factory,
await record_token_usage(
audit_session,
usage_type=usage_type, usage_type=usage_type,
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id, user_id=user_id,
@ -158,14 +293,8 @@ async def billable_call(
call_details=call_details, call_details=call_details,
thread_id=thread_id, thread_id=thread_id,
message_id=message_id, message_id=message_id,
) audit_label="free",
await audit_session.commit() timeout_seconds=audit_timeout_seconds,
except Exception:
logger.exception(
"[billable_call] free-path audit insert failed for "
"usage_type=%s user_id=%s",
usage_type,
user_id,
) )
return return
@ -180,7 +309,7 @@ async def billable_call(
request_id = str(uuid4()) request_id = str(uuid4())
async with shielded_async_session() as quota_session: async with session_factory() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve( reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session, db_session=quota_session,
user_id=user_id, user_id=user_id,
@ -222,7 +351,7 @@ async def billable_call(
# from a downstream call, asyncio cancellation, etc.). We use # from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases. # BaseException so cancellation also releases.
try: try:
async with shielded_async_session() as quota_session: async with session_factory() as quota_session:
await TokenQuotaService.premium_release( await TokenQuotaService.premium_release(
db_session=quota_session, db_session=quota_session,
user_id=user_id, user_id=user_id,
@ -241,7 +370,16 @@ async def billable_call(
# ---------- Success: finalize + audit ---------------------------- # ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros actual_micros = acc.total_cost_micros
try: try:
async with shielded_async_session() as quota_session: logger.info(
"[billable_call] finalize start user=%s usage_type=%s actual=%d "
"reserved=%d thread=%s",
user_id,
usage_type,
actual_micros,
reserve_micros,
thread_id,
)
async with session_factory() as quota_session:
final_result = await TokenQuotaService.premium_finalize( final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session, db_session=quota_session,
user_id=user_id, user_id=user_id,
@ -260,7 +398,7 @@ async def billable_call(
final_result.limit, final_result.limit,
final_result.remaining, final_result.remaining,
) )
except Exception: except Exception as finalize_exc:
# Last-ditch: if finalize itself fails, we must at least release # Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak. # so the reservation doesn't leak.
logger.exception( logger.exception(
@ -269,7 +407,7 @@ async def billable_call(
user_id, user_id,
) )
try: try:
async with shielded_async_session() as quota_session: async with session_factory() as quota_session:
await TokenQuotaService.premium_release( await TokenQuotaService.premium_release(
db_session=quota_session, db_session=quota_session,
user_id=user_id, user_id=user_id,
@ -281,11 +419,14 @@ async def billable_call(
"for user=%s", "for user=%s",
user_id, user_id,
) )
raise BillingSettlementError(
usage_type=usage_type,
user_id=user_id,
cause=finalize_exc,
) from finalize_exc
try: await _record_audit_best_effort(
async with shielded_async_session() as audit_session: session_factory=session_factory,
await record_token_usage(
audit_session,
usage_type=usage_type, usage_type=usage_type,
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id, user_id=user_id,
@ -297,14 +438,8 @@ async def billable_call(
call_details=call_details, call_details=call_details,
thread_id=thread_id, thread_id=thread_id,
message_id=message_id, message_id=message_id,
) audit_label="premium",
await audit_session.commit() timeout_seconds=audit_timeout_seconds,
except Exception:
logger.exception(
"[billable_call] premium-path audit insert failed for "
"usage_type=%s user_id=%s (debit was applied)",
usage_type,
user_id,
) )
@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space(
__all__ = [ __all__ = [
"BillingSettlementError",
"QuotaInsufficientError", "QuotaInsufficientError",
"_resolve_agent_billing_for_search_space", "_resolve_agent_billing_for_search_space",
"billable_call", "billable_call",

View file

@ -408,12 +408,37 @@ class ComposioService:
files = [] files = []
next_token = None next_token = None
if isinstance(data, dict): if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
# Try direct access first, then nested # Try direct access first, then nested
files = data.get("files", []) or data.get("data", {}).get("files", []) files = (
data.get("files", [])
or (
inner_data.get("files", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("files", [])
)
next_token = ( next_token = (
data.get("nextPageToken") data.get("nextPageToken")
or data.get("next_page_token") or data.get("next_page_token")
or data.get("data", {}).get("nextPageToken") or (
inner_data.get("nextPageToken")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("next_page_token")
if isinstance(inner_data, dict)
else None
)
or response_data.get("nextPageToken")
or response_data.get("next_page_token")
) )
elif isinstance(data, list): elif isinstance(data, list):
files = data files = data
@ -819,24 +844,61 @@ class ComposioService:
next_token = None next_token = None
result_size_estimate = None result_size_estimate = None
if isinstance(data, dict): if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
messages = ( messages = (
data.get("messages", []) data.get("messages", [])
or data.get("data", {}).get("messages", []) or (
inner_data.get("messages", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("messages", [])
or data.get("emails", []) or data.get("emails", [])
or (
inner_data.get("emails", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("emails", [])
) )
# Check for pagination token in various possible locations # Check for pagination token in various possible locations
next_token = ( next_token = (
data.get("nextPageToken") data.get("nextPageToken")
or data.get("next_page_token") or data.get("next_page_token")
or data.get("data", {}).get("nextPageToken") or (
or data.get("data", {}).get("next_page_token") inner_data.get("nextPageToken")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("next_page_token")
if isinstance(inner_data, dict)
else None
)
or response_data.get("nextPageToken")
or response_data.get("next_page_token")
) )
# Extract resultSizeEstimate if available (Gmail API provides this) # Extract resultSizeEstimate if available (Gmail API provides this)
result_size_estimate = ( result_size_estimate = (
data.get("resultSizeEstimate") data.get("resultSizeEstimate")
or data.get("result_size_estimate") or data.get("result_size_estimate")
or data.get("data", {}).get("resultSizeEstimate") or (
or data.get("data", {}).get("result_size_estimate") inner_data.get("resultSizeEstimate")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("result_size_estimate")
if isinstance(inner_data, dict)
else None
)
or response_data.get("resultSizeEstimate")
or response_data.get("result_size_estimate")
) )
elif isinstance(data, list): elif isinstance(data, list):
messages = data messages = data
@ -864,7 +926,7 @@ class ComposioService:
try: try:
result = await self.execute_tool( result = await self.execute_tool(
connected_account_id=connected_account_id, connected_account_id=connected_account_id,
tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID", tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
params={"message_id": message_id}, # snake_case params={"message_id": message_id}, # snake_case
entity_id=entity_id, entity_id=entity_id,
) )
@ -872,7 +934,13 @@ class ComposioService:
if not result.get("success"): if not result.get("success"):
return None, result.get("error", "Unknown error") return None, result.get("error", "Unknown error")
return result.get("data"), None data = result.get("data")
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
return inner_data.get("response_data", inner_data), None
return data, None
except Exception as e: except Exception as e:
logger.error(f"Failed to get Gmail message detail: {e!s}") logger.error(f"Failed to get Gmail message detail: {e!s}")
@ -928,10 +996,27 @@ class ComposioService:
# Try different possible response structures # Try different possible response structures
events = [] events = []
if isinstance(data, dict): if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
events = ( events = (
data.get("items", []) data.get("items", [])
or data.get("data", {}).get("items", []) or (
inner_data.get("items", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("items", [])
or data.get("events", []) or data.get("events", [])
or (
inner_data.get("events", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("events", [])
) )
elif isinstance(data, list): elif isinstance(data, list):
events = data events = data

View file

@ -1,6 +1,8 @@
import asyncio import asyncio
import os
import time import time
from datetime import datetime from datetime import datetime
from threading import Lock
from typing import Any from typing import Any
import httpx import httpx
@ -2769,12 +2771,22 @@ class ConnectorService:
""" """
Get all available (enabled) connector types for a search space. Get all available (enabled) connector types for a search space.
Phase 1.4: results are cached per ``search_space_id`` for
:data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
identity the cached value is plain data, safe to share across
requests. Invalidate on connector add/update/delete via
:func:`invalidate_connector_discovery_cache`.
Args: Args:
search_space_id: The search space ID search_space_id: The search space ID
Returns: Returns:
List of SearchSourceConnectorType enums for enabled connectors List of SearchSourceConnectorType enums for enabled connectors
""" """
cached = _get_cached_connectors(search_space_id)
if cached is not None:
return list(cached)
query = ( query = (
select(SearchSourceConnector.connector_type) select(SearchSourceConnector.connector_type)
.filter( .filter(
@ -2784,8 +2796,9 @@ class ConnectorService:
) )
result = await self.session.execute(query) result = await self.session.execute(query)
connector_types = result.scalars().all() connector_types = list(result.scalars().all())
return list(connector_types) _set_cached_connectors(search_space_id, connector_types)
return connector_types
async def get_available_document_types( async def get_available_document_types(
self, self,
@ -2794,12 +2807,22 @@ class ConnectorService:
""" """
Get all document types that have at least one document in the search space. Get all document types that have at least one document in the search space.
Phase 1.4: cached per ``search_space_id`` for
:data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
:func:`invalidate_connector_discovery_cache` when a connector
finishes indexing new documents (or document types are otherwise
added/removed).
Args: Args:
search_space_id: The search space ID search_space_id: The search space ID
Returns: Returns:
List of document type strings that have documents indexed List of document type strings that have documents indexed
""" """
cached = _get_cached_doc_types(search_space_id)
if cached is not None:
return list(cached)
from sqlalchemy import distinct from sqlalchemy import distinct
from app.db import Document from app.db import Document
@ -2809,5 +2832,164 @@ class ConnectorService:
) )
result = await self.session.execute(query) result = await self.session.execute(query)
doc_types = result.scalars().all() doc_types = [str(dt) for dt in result.scalars().all()]
return [str(dt) for dt in doc_types] _set_cached_doc_types(search_space_id, doc_types)
return doc_types
# ---------------------------------------------------------------------------
# Connector / document-type discovery TTL cache (Phase 1.4)
# ---------------------------------------------------------------------------
#
# Both ``get_available_connectors`` and ``get_available_document_types`` are
# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
# hits Postgres and contributes to per-turn agent build latency. Their
# results change infrequently — only when the user adds/edits/removes a
# connector, or when an indexer commits a new document type. A short TTL
# cache (default 30s, env-tunable) collapses N concurrent calls into one
# DB roundtrip with bounded staleness.
#
# Invalidation: connector mutation routes (create / update / delete) call
# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
# entry for the affected space. Multi-replica deployments still pay one
# DB roundtrip per replica per TTL window, which is fine — staleness is
# bounded and the alternative (cross-replica fanout) is not worth the
# coupling here.
_DISCOVERY_TTL_SECONDS: float = float(
os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
)
# Per-search-space caches. Keyed by ``search_space_id``; value is
# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
# read-mostly workload, sub-microsecond contention.
_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
_cache_lock = Lock()
def _get_cached_connectors(
search_space_id: int,
) -> list[SearchSourceConnectorType] | None:
if _DISCOVERY_TTL_SECONDS <= 0:
return None
with _cache_lock:
entry = _connectors_cache.get(search_space_id)
if entry is None:
return None
expires_at, payload = entry
if time.monotonic() >= expires_at:
_connectors_cache.pop(search_space_id, None)
return None
return payload
def _set_cached_connectors(
search_space_id: int, payload: list[SearchSourceConnectorType]
) -> None:
if _DISCOVERY_TTL_SECONDS <= 0:
return
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
with _cache_lock:
_connectors_cache[search_space_id] = (expires_at, list(payload))
def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
if _DISCOVERY_TTL_SECONDS <= 0:
return None
with _cache_lock:
entry = _doc_types_cache.get(search_space_id)
if entry is None:
return None
expires_at, payload = entry
if time.monotonic() >= expires_at:
_doc_types_cache.pop(search_space_id, None)
return None
return payload
def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
if _DISCOVERY_TTL_SECONDS <= 0:
return
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
with _cache_lock:
_doc_types_cache[search_space_id] = (expires_at, list(payload))
def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
"""Drop cached discovery results for ``search_space_id`` (or all spaces).
Connector CRUD routes / indexer pipelines call this when they mutate
the rows backing :func:`ConnectorService.get_available_connectors` /
:func:`get_available_document_types`. ``None`` clears every space
useful in tests and on bulk imports.
"""
with _cache_lock:
if search_space_id is None:
_connectors_cache.clear()
_doc_types_cache.clear()
else:
_connectors_cache.pop(search_space_id, None)
_doc_types_cache.pop(search_space_id, None)
def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
with _cache_lock:
if search_space_id is None:
_connectors_cache.clear()
else:
_connectors_cache.pop(search_space_id, None)
def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
with _cache_lock:
if search_space_id is None:
_doc_types_cache.clear()
else:
_doc_types_cache.pop(search_space_id, None)
def _register_invalidation_listeners() -> None:
"""Wire SQLAlchemy ORM events so cache stays consistent automatically.
Listening on ``after_insert`` / ``after_update`` / ``after_delete``
means every successful INSERT/UPDATE/DELETE that goes through the ORM
invalidates the affected search space's cached discovery payload —
no need to sprinkle ``invalidate_*`` calls across 30+ connector
routes. Bulk operations that bypass the ORM (e.g.
``session.execute(insert(...))`` without a mapped object) still need
explicit invalidation; document indexers already commit through the
ORM so document-type discovery is covered.
"""
from sqlalchemy import event
# Imported here (not at module top) to avoid a circular import:
# app.services.connector_service is itself imported from app.db's
# ecosystem indirectly via several CRUD modules.
from app.db import Document, SearchSourceConnector
def _connector_changed(_mapper, _connection, target) -> None:
sid = getattr(target, "search_space_id", None)
if sid is not None:
_invalidate_connectors_only(int(sid))
def _document_changed(_mapper, _connection, target) -> None:
sid = getattr(target, "search_space_id", None)
if sid is not None:
_invalidate_doc_types_only(int(sid))
for evt in ("after_insert", "after_update", "after_delete"):
event.listen(SearchSourceConnector, evt, _connector_changed)
event.listen(Document, evt, _document_changed)
try:
_register_invalidation_listeners()
except Exception: # pragma: no cover - defensive; never block module import
import logging as _logging
_logging.getLogger(__name__).exception(
"Failed to register connector discovery cache invalidation listeners; "
"stale cache risk: explicit invalidate_connector_discovery_cache calls "
"may be required."
)

View file

@ -17,7 +17,7 @@ from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
) )
from app.utils.google_credentials import build_composio_credentials from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -78,14 +78,49 @@ class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession): def __init__(self, db_session: AsyncSession):
self._db_session = db_session self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
if ( return (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
): )
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
return build_composio_credentials(cca_id) raise ValueError("Composio connected_account_id not found")
return cca_id
def _unwrap_composio_data(self, data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def _execute_composio_gmail_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict[str, Any],
) -> tuple[Any, str | None]:
result = await ComposioService().execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Gmail error")
return self._unwrap_composio_data(result.get("data")), None
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if self._is_composio_connector(connector):
raise ValueError(
"Composio Gmail connectors must use Composio tool execution"
)
config_data = dict(connector.config) config_data = dict(connector.config)
@ -139,6 +174,12 @@ class GmailToolMetadataService:
if not connector: if not connector:
return True return True
if self._is_composio_connector(connector):
_profile, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
)
return bool(error)
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds) service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor( await asyncio.get_event_loop().run_in_executor(
@ -221,6 +262,13 @@ class GmailToolMetadataService:
) )
connector = result.scalar_one_or_none() connector = result.scalar_one_or_none()
if connector: if connector:
if self._is_composio_connector(connector):
profile, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
)
if error:
raise RuntimeError(error)
else:
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds) service = build("gmail", "v1", credentials=creds)
profile = await asyncio.get_event_loop().run_in_executor( profile = await asyncio.get_event_loop().run_in_executor(
@ -298,6 +346,23 @@ class GmailToolMetadataService:
Returns ``None`` on any failure so callers can degrade gracefully. Returns ``None`` on any failure so callers can degrade gracefully.
""" """
try: try:
if self._is_composio_connector(connector):
if not draft_id:
draft_id = await self._find_composio_draft_id(connector, message_id)
if not draft_id:
return None
draft, error = await self._execute_composio_gmail_tool(
connector,
"GMAIL_GET_DRAFT",
{"user_id": "me", "draft_id": draft_id, "format": "full"},
)
if error or not isinstance(draft, dict):
return None
payload = draft.get("message", {}).get("payload", {})
return self._extract_body_from_payload(payload)
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds) service = build("gmail", "v1", credentials=creds)
@ -326,6 +391,33 @@ class GmailToolMetadataService:
) )
return None return None
async def _find_composio_draft_id(
self, connector: SearchSourceConnector, message_id: str
) -> str | None:
page_token = ""
while True:
params: dict[str, Any] = {
"user_id": "me",
"max_results": 100,
"verbose": False,
}
if page_token:
params["page_token"] = page_token
data, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_LIST_DRAFTS", params
)
if error or not isinstance(data, dict):
return None
for draft in data.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft.get("id")
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
if not page_token:
return None
async def _find_draft_id(self, service: Any, message_id: str) -> str | None: async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
"""Resolve a draft ID from its message ID by scanning drafts.list.""" """Resolve a draft ID from its message ID by scanning drafts.list."""
try: try:

View file

@ -14,6 +14,7 @@ from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
) )
from app.services.composio_service import ComposioService
from app.utils.document_converters import ( from app.utils.document_converters import (
create_document_chunks, create_document_chunks,
embed_text, embed_text,
@ -21,7 +22,6 @@ from app.utils.document_converters import (
generate_document_summary, generate_document_summary,
generate_unique_identifier_hash, generate_unique_identifier_hash,
) )
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -203,15 +203,38 @@ class GoogleCalendarKBSyncService:
logger.warning("Document %s not found in KB", document_id) logger.warning("Document %s not found in KB", document_id)
return {"status": "not_indexed"} return {"status": "not_indexed"}
calendar_id = (document.document_metadata or {}).get(
"calendar_id"
) or "primary"
connector = await self._get_connector(connector_id)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_EVENTS_GET",
params={"calendar_id": calendar_id, "event_id": event_id},
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get("error", "Unknown Composio Calendar error")
)
live_event = composio_result.get("data", {})
if isinstance(live_event, dict):
live_event = live_event.get("data", live_event)
if isinstance(live_event, dict):
live_event = live_event.get("response_data", live_event)
else:
creds = await self._build_credentials_for_connector(connector_id) creds = await self._build_credentials_for_connector(connector_id)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
service = await loop.run_in_executor( service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds) None, lambda: build("calendar", "v3", credentials=creds)
) )
calendar_id = (document.document_metadata or {}).get(
"calendar_id"
) or "primary"
live_event = await loop.run_in_executor( live_event = await loop.run_in_executor(
None, None,
lambda: ( lambda: (
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
await self.db_session.rollback() await self.db_session.rollback()
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials: async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
result = await self.db_session.execute( result = await self.db_session.execute(
select(SearchSourceConnector).where( select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id SearchSourceConnector.id == connector_id
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
connector = result.scalar_one_or_none() connector = result.scalar_one_or_none()
if not connector: if not connector:
raise ValueError(f"Connector {connector_id} not found") raise ValueError(f"Connector {connector_id} not found")
return connector
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
connector = await self._get_connector(connector_id)
if ( if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): ):
cca_id = connector.config.get("composio_connected_account_id") raise ValueError(
if cca_id: "Composio Calendar connectors must use Composio tool execution"
return build_composio_credentials(cca_id) )
raise ValueError("Composio connected_account_id not found")
config_data = dict(connector.config) config_data = dict(connector.config)

View file

@ -16,7 +16,7 @@ from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
) )
from app.utils.google_credentials import build_composio_credentials from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
def __init__(self, db_session: AsyncSession): def __init__(self, db_session: AsyncSession):
self._db_session = db_session self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials: def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
if ( return (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): )
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
return build_composio_credentials(cca_id)
raise ValueError("Composio connected_account_id not found") raise ValueError("Composio connected_account_id not found")
return cca_id
async def _execute_composio_calendar_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict,
) -> tuple[dict | list | None, str | None]:
service = ComposioService()
result = await service.execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Calendar error")
data = result.get("data")
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner), None
return inner, None
return data, None
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if self._is_composio_connector(connector):
raise ValueError(
"Composio Calendar connectors must use Composio tool execution"
)
config_data = dict(connector.config) config_data = dict(connector.config)
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
if not connector: if not connector:
return True return True
if self._is_composio_connector(connector):
_data, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_GET_CALENDAR",
{"calendar_id": "primary"},
)
return bool(error)
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(
@ -255,6 +297,23 @@ class GoogleCalendarToolMetadataService:
timezone_str = "" timezone_str = ""
if connector: if connector:
try: try:
if self._is_composio_connector(connector):
cal_list, cal_error = await self._execute_composio_calendar_tool(
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
)
if cal_error:
raise RuntimeError(cal_error)
(
settings,
settings_error,
) = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_SETTINGS_GET",
{"setting": "timezone"},
)
if not settings_error and isinstance(settings, dict):
timezone_str = settings.get("value", "")
else:
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
service = await loop.run_in_executor( service = await loop.run_in_executor(
@ -264,7 +323,22 @@ class GoogleCalendarToolMetadataService:
cal_list = await loop.run_in_executor( cal_list = await loop.run_in_executor(
None, lambda: service.calendarList().list().execute() None, lambda: service.calendarList().list().execute()
) )
for cal in cal_list.get("items", []):
tz_setting = await loop.run_in_executor(
None,
lambda: service.settings().get(setting="timezone").execute(),
)
timezone_str = tz_setting.get("value", "")
calendar_items = []
if isinstance(cal_list, dict):
calendar_items = (
cal_list.get("items") or cal_list.get("calendars") or []
)
elif isinstance(cal_list, list):
calendar_items = cal_list
for cal in calendar_items:
calendars.append( calendars.append(
{ {
"id": cal.get("id", ""), "id": cal.get("id", ""),
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
"primary": cal.get("primary", False), "primary": cal.get("primary", False),
} }
) )
tz_setting = await loop.run_in_executor(
None,
lambda: service.settings().get(setting="timezone").execute(),
)
timezone_str = tz_setting.get("value", "")
except Exception: except Exception:
logger.warning( logger.warning(
"Failed to fetch calendars/timezone for connector %s", "Failed to fetch calendars/timezone for connector %s",
@ -321,12 +389,21 @@ class GoogleCalendarToolMetadataService:
event_dict = event.to_dict() event_dict = event.to_dict()
try: try:
calendar_id = event.calendar_id or "primary"
if self._is_composio_connector(connector):
live_event, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_EVENTS_GET",
{"calendar_id": calendar_id, "event_id": event.event_id},
)
if error:
raise RuntimeError(error)
else:
creds = await self._build_credentials(connector) creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
service = await loop.run_in_executor( service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds) None, lambda: build("calendar", "v3", credentials=creds)
) )
calendar_id = event.calendar_id or "primary"
live_event = await loop.run_in_executor( live_event = await loop.run_in_executor(
None, None,
lambda: ( lambda: (
@ -376,14 +453,32 @@ class GoogleCalendarToolMetadataService:
) -> dict: ) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref) resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved: if not resolved:
live_resolved = await self._resolve_live_event(
search_space_id, user_id, event_ref
)
if not live_resolved:
return { return {
"error": ( "error": (
f"Event '{event_ref}' not found in your indexed Google Calendar events. " f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, " "This could mean: (1) the event doesn't exist, "
"or (3) the event name is different." "(2) the event name is different, or "
"(3) the connected calendar account cannot access it."
) )
} }
connector, live_event = live_resolved
account = GoogleCalendarAccount.from_connector(connector)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return {
"account": acc_dict,
"event": self._event_dict_from_live_event(live_event),
}
document, connector = resolved document, connector = resolved
account = GoogleCalendarAccount.from_connector(connector) account = GoogleCalendarAccount.from_connector(connector)
event = GoogleCalendarEvent.from_document(document) event = GoogleCalendarEvent.from_document(document)
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
if row: if row:
return row[0], row[1] return row[0], row[1]
return None return None
async def _resolve_live_event(
self, search_space_id: int, user_id: str, event_ref: str
) -> tuple[SearchSourceConnector, dict] | None:
result = await self._db_session.execute(
select(SearchSourceConnector)
.filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
)
)
.order_by(SearchSourceConnector.last_indexed_at.desc())
)
connectors = result.scalars().all()
for connector in connectors:
try:
events = await self._search_live_events(connector, event_ref)
except Exception:
logger.warning(
"Failed to search live calendar events for connector %s",
connector.id,
exc_info=True,
)
continue
if not events:
continue
normalized_ref = event_ref.strip().lower()
exact_match = next(
(
event
for event in events
if event.get("summary", "").strip().lower() == normalized_ref
),
None,
)
return connector, exact_match or events[0]
return None
async def _search_live_events(
self, connector: SearchSourceConnector, event_ref: str
) -> list[dict]:
if self._is_composio_connector(connector):
data, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_EVENTS_LIST",
{
"calendar_id": "primary",
"q": event_ref,
"max_results": 10,
"single_events": True,
"order_by": "startTime",
},
)
if error:
raise RuntimeError(error)
if isinstance(data, dict):
return data.get("items") or data.get("events") or []
return data if isinstance(data, list) else []
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
response = await loop.run_in_executor(
None,
lambda: (
service.events()
.list(
calendarId="primary",
q=event_ref,
maxResults=10,
singleEvents=True,
orderBy="startTime",
)
.execute()
),
)
return response.get("items", [])
def _event_dict_from_live_event(self, event: dict) -> dict:
start_data = event.get("start", {})
end_data = event.get("end", {})
return {
"event_id": event.get("id", ""),
"summary": event.get("summary", "No Title"),
"start": start_data.get("dateTime", start_data.get("date", "")),
"end": end_data.get("dateTime", end_data.get("date", "")),
"description": event.get("description", ""),
"location": event.get("location", ""),
"attendees": [
{
"email": attendee.get("email", ""),
"responseStatus": attendee.get("responseStatus", ""),
}
for attendee in event.get("attendees", [])
],
"calendar_id": event.get("calendarId", "primary"),
"document_id": None,
"indexed_at": None,
}

View file

@ -13,7 +13,7 @@ from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
) )
from app.utils.google_credentials import build_composio_credentials from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
def __init__(self, db_session: AsyncSession): def __init__(self, db_session: AsyncSession):
self._db_session = db_session self._db_session = db_session
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
return cca_id
async def _execute_composio_drive_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict,
) -> tuple[dict | list | None, str | None]:
result = await ComposioService().execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Drive error")
data = result.get("data")
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner), None
return inner, None
return data, None
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict: async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_google_drive_accounts(search_space_id, user_id) accounts = await self._get_google_drive_accounts(search_space_id, user_id)
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
if not connector: if not connector:
return True return True
pre_built_creds = None if self._is_composio_connector(connector):
if ( _data, error = await self._execute_composio_drive_tool(
connector.connector_type connector,
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR "GOOGLEDRIVE_LIST_FILES",
): {
cca_id = connector.config.get("composio_connected_account_id") "q": "trashed = false",
if cca_id: "page_size": 1,
pre_built_creds = build_composio_credentials(cca_id) "fields": "files(id)",
},
)
return bool(error)
client = GoogleDriveClient( client = GoogleDriveClient(
session=self._db_session, session=self._db_session,
connector_id=connector_id, connector_id=connector_id,
credentials=pre_built_creds,
) )
await client.list_files( await client.list_files(
query="trashed = false", page_size=1, fields="files(id)" query="trashed = false", page_size=1, fields="files(id)"
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
parent_folders[connector_id] = [] parent_folders[connector_id] = []
continue continue
pre_built_creds = None if self._is_composio_connector(connector):
if ( data, error = await self._execute_composio_drive_tool(
connector.connector_type connector,
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR "GOOGLEDRIVE_LIST_FILES",
): {
cca_id = connector.config.get("composio_connected_account_id") "q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
if cca_id: "fields": "files(id,name)",
pre_built_creds = build_composio_credentials(cca_id) "page_size": 50,
},
)
if error:
logger.warning(
"Failed to list folders for connector %s: %s",
connector_id,
error,
)
parent_folders[connector_id] = []
continue
folders = []
if isinstance(data, dict):
folders = data.get("files", [])
elif isinstance(data, list):
folders = data
parent_folders[connector_id] = [
{"folder_id": f["id"], "name": f["name"]}
for f in folders
if f.get("id") and f.get("name")
]
continue
client = GoogleDriveClient( client = GoogleDriveClient(
session=self._db_session, session=self._db_session,
connector_id=connector_id, connector_id=connector_id,
credentials=pre_built_creds,
) )
folders, _, error = await client.list_files( folders, _, error = await client.list_files(

View file

@ -20,6 +20,8 @@ from typing import Any
from litellm import Router from litellm import Router
from litellm.utils import ImageResponse from litellm.utils import ImageResponse
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing # Special ID for Auto mode - uses router for load balancing
@ -152,10 +154,10 @@ class ImageGenRouterService:
return None return None
# Build model string # Build model string
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
else:
provider = config.get("provider", "").upper() provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
else:
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}" model_string = f"{provider_prefix}/{config['model_name']}"
@ -165,9 +167,16 @@ class ImageGenRouterService:
"api_key": config.get("api_key"), "api_key": config.get("api_key"),
} }
# Add optional api_base # Resolve ``api_base`` so deployments don't silently inherit
if config.get("api_base"): # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
litellm_params["api_base"] = config["api_base"] # the wrong provider (see ``provider_api_base`` docstring).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
# Add api_version (required for Azure) # Add api_version (required for Azure)
if config.get("api_version"): if config.get("api_version"):

View file

@ -140,8 +140,6 @@ PROVIDER_MAP = {
# 404-ing against an inherited Azure endpoint). Re-exported here for # 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import. # backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402 from app.services.provider_api_base import ( # noqa: E402
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base, resolve_api_base,
) )

View file

@ -16,6 +16,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm, get_auto_mode_llm,
is_auto_mode, is_auto_mode,
) )
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import token_tracker from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters # Configure litellm to automatically drop unsupported parameters
@ -556,22 +557,26 @@ async def get_vision_llm(
return None return None
if global_cfg.get("custom_provider"): if global_cfg.get("custom_provider"):
model_string = ( provider_prefix = global_cfg["custom_provider"]
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}" model_string = f"{provider_prefix}/{global_cfg['model_name']}"
)
else: else:
prefix = VISION_PROVIDER_MAP.get( provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(), global_cfg["provider"].upper(),
global_cfg["provider"].lower(), global_cfg["provider"].lower(),
) )
model_string = f"{prefix}/{global_cfg['model_name']}" model_string = f"{provider_prefix}/{global_cfg['model_name']}"
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,
"api_key": global_cfg["api_key"], "api_key": global_cfg["api_key"],
} }
if global_cfg.get("api_base"): api_base = resolve_api_base(
litellm_kwargs["api_base"] = global_cfg["api_base"] provider=global_cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=global_cfg.get("api_base"),
)
if api_base:
litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"): if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"]) litellm_kwargs.update(global_cfg["litellm_params"])
@ -606,20 +611,26 @@ async def get_vision_llm(
return None return None
if vision_cfg.custom_provider: if vision_cfg.custom_provider:
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}" provider_prefix = vision_cfg.custom_provider
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else: else:
prefix = VISION_PROVIDER_MAP.get( provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(), vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(), vision_cfg.provider.value.lower(),
) )
model_string = f"{prefix}/{vision_cfg.model_name}" model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,
"api_key": vision_cfg.api_key, "api_key": vision_cfg.api_key,
} }
if vision_cfg.api_base: api_base = resolve_api_base(
litellm_kwargs["api_base"] = vision_cfg.api_base provider=vision_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=vision_cfg.api_base,
)
if api_base:
litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params: if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params) litellm_kwargs.update(vision_cfg.litellm_params)

View file

@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool:
return "image" in input_mods and "text" in output_mods return "image" in input_mods and "text" in output_mods
def _supports_image_input(model: dict) -> bool:
"""Return True if the model accepts ``image`` in its input modalities.
Differs from :func:`_is_vision_input_model` in that it does NOT
require text output chat-tab models always emit text already (the
chat catalog filters by ``_is_text_output_model``), so the only
extra capability we need to track per chat config is whether the
model can ingest user-attached images. The chat selector and the
streaming task both key off this flag to prevent hitting an
OpenRouter 404 ``"No endpoints found that support image input"``
when the user uploads an image and selects a text-only model
(DeepSeek V3, Llama 3.x base, etc.).
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
return "image" in input_mods
def _supports_tool_calling(model: dict) -> bool: def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling.""" """Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or [] supported = model.get("supported_parameters") or []
@ -321,6 +339,13 @@ def _generate_configs(
# account-wide quota, so per-deployment routing can't spread load # account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster. # there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium", "router_pool_eligible": tier == "premium",
# Capability flag derived from ``architecture.input_modalities``.
# Read by the new-chat selector to dim image-incompatible models
# when the user has pending image attachments, and by
# ``stream_new_chat`` as a fail-fast safety net before the
# OpenRouter request would otherwise 404 with
# ``"No endpoints found that support image input"``.
"supports_image_input": _supports_image_input(model),
_OPENROUTER_DYNAMIC_MARKER: True, _OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised # Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next # to the static score and gets re-blended with health on the next
@ -398,7 +423,12 @@ def _generate_image_gen_configs(
"provider": "OPENROUTER", "provider": "OPENROUTER",
"model_name": model_id, "model_name": model_id,
"api_key": api_key, "api_key": api_key,
"api_base": "", # Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
# ``image_generation/transformation`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None, "api_version": None,
"rpm": free_rpm if tier == "free" else rpm, "rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params), "litellm_params": dict(litellm_params),
@ -477,7 +507,11 @@ def _generate_vision_llm_configs(
"provider": "OPENROUTER", "provider": "OPENROUTER",
"model_name": model_id, "model_name": model_id,
"api_key": api_key, "api_key": api_key,
"api_base": "", # Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None, "api_version": None,
"rpm": free_rpm if tier == "free" else rpm, "rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm, "tpm": free_tpm if tier == "free" else tpm,

View file

@ -17,7 +17,6 @@ source of truth without an inter-service circular import.
from __future__ import annotations from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = { PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1", "openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1", "groq": "https://api.groq.com/openai/v1",

View file

@ -0,0 +1,280 @@
"""Capability resolution shared by chat / image / vision call sites.
Why this exists
---------------
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
single, authoritative answer to one question: *can this chat config accept
``image_url`` content blocks?* Without it, the new-chat selector can't badge
incompatible models and the streaming task can't fail fast with a friendly
error before sending an image to a text-only provider.
Two functions, two intents:
- :func:`derive_supports_image_input` best-effort *True* for catalog and
UI surfacing. Default-allow: an unknown / unmapped model is treated as
capable so we never lock the user out of a freshly added or
third-party-hosted vision model.
- :func:`is_known_text_only_chat_model` strict opt-out for the streaming
task's safety net. Returns True only when LiteLLM's model map *explicitly*
sets ``supports_vision=False`` (or its bare-name variant does). Anything
else missing key, lookup exception, ``supports_vision=True`` returns
False so the request flows through to the provider.
Implementation rule: only public LiteLLM symbols
------------------------------------------------
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
across releases. The private ``_is_explicitly_disabled_factory`` and
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
can't silently break us.
Why the previous round's strict YAML opt-in flag failed
-------------------------------------------------------
``supports_image_input: false`` was the YAML loader's setdefault. Operators
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
YAML chat model including vision-capable GPT-5.x and GPT-4o resolved to
False and the streaming gate rejected every image turn. Sourcing capability
from LiteLLM's authoritative model map (which already says
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
import litellm
logger = logging.getLogger(__name__)
# Provider-name → LiteLLM model-prefix map.
#
# Owned here because ``app.services.provider_capabilities`` is the
# only edge that's safe to call from ``app.config``'s YAML loader at
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
# this constant under the historical ``PROVIDER_MAP`` name; placing the
# map there directly would re-introduce the
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
# app.config`` cycle that prompted the move.
_PROVIDER_PREFIX_MAP: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _candidate_model_strings(
*,
provider: str | None,
model_name: str | None,
base_model: str | None,
custom_provider: str | None,
) -> list[tuple[str, str | None]]:
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
LiteLLM's capability lookup is keyed by ``model`` + (optional)
``custom_llm_provider``. Different config sources give us different
levels of detail, so we try the most-specific keys first and fall back
to bare model names so unannotated entries (e.g. an Azure deployment
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
map. Order matters the first lookup that returns a definitive answer
wins for both helpers.
"""
candidates: list[tuple[str, str | None]] = []
seen: set[tuple[str, str | None]] = set()
def _add(model: str | None, llm_provider: str | None) -> None:
if not model:
return
key = (model, llm_provider)
if key in seen:
return
seen.add(key)
candidates.append(key)
provider_prefix: str | None = None
if provider:
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
if custom_provider:
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
provider_prefix = custom_provider
primary_model = base_model or model_name
bare_model = model_name
# Most-specific first: provider-prefixed identifier with explicit
# custom_llm_provider so LiteLLM won't have to guess the provider via
# ``get_llm_provider``.
if primary_model and provider_prefix:
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
if "/" in primary_model:
_add(primary_model, provider_prefix)
else:
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
# Bare base_model (or model_name) with provider hint — handles entries
# the upstream map keys without a provider prefix (most ``gpt-*`` and
# ``claude-*`` entries do this).
if primary_model:
_add(primary_model, provider_prefix)
# Fallback to model_name when base_model differs (e.g. an Azure
# deployment whose model_name is the deployment id but base_model is the
# canonical OpenAI sku).
if bare_model and bare_model != primary_model:
if provider_prefix and "/" not in bare_model:
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
_add(bare_model, provider_prefix)
_add(bare_model, None)
return candidates
def derive_supports_image_input(
*,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
openrouter_input_modalities: Iterable[str] | None = None,
) -> bool:
"""Best-effort capability flag for the new-chat selector and catalog.
Resolution order (first definitive answer wins):
1. ``openrouter_input_modalities`` (when provided as a non-empty
iterable). OpenRouter exposes ``architecture.input_modalities`` per
model and that's the authoritative source for OR dynamic configs.
2. ``litellm.supports_vision`` against each candidate identifier from
:func:`_candidate_model_strings`. Returns True as soon as any
candidate confirms vision support.
3. Default ``True`` the conservative-allow stance. An unknown /
newly-added / third-party-hosted model is *not* pre-judged. The
streaming safety net (:func:`is_known_text_only_chat_model`) is the
only place a False ever blocks; everywhere else, a False here would
just hide a usable model from the user.
Returns:
True if the model can plausibly accept image input, False only when
OpenRouter explicitly says it can't.
"""
if openrouter_input_modalities is not None:
modalities = list(openrouter_input_modalities)
if modalities:
return "image" in modalities
# Empty list explicitly published by OR — treat as "no image".
return False
for model_string, custom_llm_provider in _candidate_model_strings(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
):
try:
if litellm.supports_vision(
model=model_string, custom_llm_provider=custom_llm_provider
):
return True
except Exception as exc:
logger.debug(
"litellm.supports_vision raised for model=%s provider=%s: %s",
model_string,
custom_llm_provider,
exc,
)
continue
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
return True
def is_known_text_only_chat_model(
*,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
) -> bool:
"""Strict opt-out probe for the streaming-task safety net.
Returns True only when LiteLLM's model map *explicitly* sets
``supports_vision=False`` for at least one candidate identifier. Missing
key, lookup exception, or ``supports_vision=True`` all return False so
the streaming task lets the request through. This is the inverse-default
of :func:`derive_supports_image_input`.
Why two functions
-----------------
The selector wants "show me everything that's plausibly capable"
default-allow. The safety net wants "block only when I'm certain it
can't" — default-pass. Mixing the two intents in a single function
leads to the regression we're fixing here.
"""
for model_string, custom_llm_provider in _candidate_model_strings(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
):
try:
info = litellm.get_model_info(
model=model_string, custom_llm_provider=custom_llm_provider
)
except Exception as exc:
logger.debug(
"litellm.get_model_info raised for model=%s provider=%s: %s",
model_string,
custom_llm_provider,
exc,
)
continue
# ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
# may be missing, None, True, or False. We only fire on explicit
# False — None / missing / True all mean "don't block".
try:
value = info.get("supports_vision") # type: ignore[union-attr]
except AttributeError:
value = None
if value is False:
return True
return False
__all__ = [
"derive_supports_image_input",
"is_known_text_only_chat_model",
]

View file

@ -1,10 +1,25 @@
"""Celery tasks package.""" """Celery tasks package.
Also hosts the small helpers every async celery task should use to
spin up its event loop. See :func:`run_async_celery_task` for the
canonical pattern.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from app.config import config from app.config import config
logger = logging.getLogger(__name__)
_celery_engine = None _celery_engine = None
_celery_session_maker = None _celery_session_maker = None
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False _celery_engine, expire_on_commit=False
) )
return _celery_session_maker return _celery_session_maker
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
"""Drop the shared ``app.db.engine`` connection pool synchronously.
The shared engine (used by ``shielded_async_session`` and most
routes / services) is a module-level singleton with a real pool.
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
connections cache a reference to whichever loop opened them. When
a subsequent task's loop pulls a stale connection from the pool,
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
AttributeError: 'NoneType' object has no attribute 'send'
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
self._write_fut = self._loop._proactor.send(self._sock, data)
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
coroutine that can never run because its loop is gone.
Disposing the engine forces the pool to drop every cached
connection so the next checkout opens a fresh one on the current
loop. Safe to call from a task's finally block; failure is logged
but never propagated.
"""
try:
from app.db import engine as shared_engine
loop.run_until_complete(shared_engine.dispose())
except Exception:
logger.warning("Shared DB engine dispose() failed", exc_info=True)
T = TypeVar("T")
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
"""Run an async coroutine inside a fresh event loop with proper
DB-engine cleanup.
This is the canonical entry point for every async celery task.
It performs three responsibilities that were previously copy-pasted
(incorrectly) across each task module:
1. Create a fresh ``asyncio`` loop and install it on the current
thread (celery's ``--pool=solo`` runs every task on the main
thread, but other pool types don't).
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
any stale connections left over from a previous task's loop
are dropped defends against tasks that crashed without
cleaning up.
3. Dispose the shared engine AFTER the task runs so the
connections we opened on this loop are released before the
loop closes (avoids ``coroutine 'Connection._cancel' was
never awaited`` warnings and the next-task hang).
Use as::
@celery_app.task(name="my_task", bind=True)
def my_task(self, *args):
return run_async_celery_task(lambda: _my_task_impl(*args))
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Defense-in-depth: prior task may have crashed before
# disposing. Idempotent — no-op if pool is already empty.
_dispose_shared_db_engine(loop)
return loop.run_until_complete(coro_factory())
finally:
# Drop any connections this task opened so they don't leak
# into the next task's loop.
_dispose_shared_db_engine(loop)
with contextlib.suppress(Exception):
loop.run_until_complete(loop.shutdown_asyncgens())
with contextlib.suppress(Exception):
asyncio.set_event_loop(None)
loop.close()
__all__ = [
"get_celery_session_maker",
"run_async_celery_task",
]

View file

@ -4,7 +4,7 @@ import logging
import traceback import traceback
from app.celery_app import celery_app from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Notion pages.""" """Celery task to index Notion pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete( return run_async_celery_task(
_index_notion_pages( lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
except Exception as e: except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id) _handle_greenlet_error(e, "index_notion_pages", connector_id)
raise raise
finally:
loop.close()
async def _index_notion_pages( async def _index_notion_pages(
@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str, end_date: str,
): ):
"""Celery task to index GitHub repositories.""" """Celery task to index GitHub repositories."""
import asyncio return run_async_celery_task(
lambda: _index_github_repos(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_github_repos( async def _index_github_repos(
@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Confluence pages.""" """Celery task to index Confluence pages."""
import asyncio return run_async_celery_task(
lambda: _index_confluence_pages(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_confluence_pages( async def _index_confluence_pages(
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Google Calendar events.""" """Celery task to index Google Calendar events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete( return run_async_celery_task(
_index_google_calendar_events( lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
except Exception as e: except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id) _handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise raise
finally:
loop.close()
async def _index_google_calendar_events( async def _index_google_calendar_events(
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Google Gmail messages.""" """Celery task to index Google Gmail messages."""
import asyncio return run_async_celery_task(
lambda: _index_google_gmail_messages(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_google_gmail_messages( async def _index_google_gmail_messages(
@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options' items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
): ):
"""Celery task to index Google Drive folders and files.""" """Celery task to index Google Drive folders and files."""
import asyncio return run_async_celery_task(
lambda: _index_google_drive_files(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_drive_files(
connector_id, connector_id,
search_space_id, search_space_id,
user_id, user_id,
items_dict, items_dict,
) )
) )
finally:
loop.close()
async def _index_google_drive_files( async def _index_google_drive_files(
@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict, items_dict: dict,
): ):
"""Celery task to index OneDrive folders and files.""" """Celery task to index OneDrive folders and files."""
import asyncio return run_async_celery_task(
lambda: _index_onedrive_files(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_onedrive_files(
connector_id, connector_id,
search_space_id, search_space_id,
user_id, user_id,
items_dict, items_dict,
) )
) )
finally:
loop.close()
async def _index_onedrive_files( async def _index_onedrive_files(
@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict, items_dict: dict,
): ):
"""Celery task to index Dropbox folders and files.""" """Celery task to index Dropbox folders and files."""
import asyncio return run_async_celery_task(
lambda: _index_dropbox_files(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_dropbox_files(
connector_id, connector_id,
search_space_id, search_space_id,
user_id, user_id,
items_dict, items_dict,
) )
) )
finally:
loop.close()
async def _index_dropbox_files( async def _index_dropbox_files(
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Elasticsearch documents.""" """Celery task to index Elasticsearch documents."""
import asyncio return run_async_celery_task(
lambda: _index_elasticsearch_documents(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_elasticsearch_documents( async def _index_elasticsearch_documents(
@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str, end_date: str,
): ):
"""Celery task to index Web page Urls.""" """Celery task to index Web page Urls."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete( return run_async_celery_task(
_index_crawled_urls( lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
except Exception as e: except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id) _handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise raise
finally:
loop.close()
async def _index_crawled_urls( async def _index_crawled_urls(
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str, end_date: str,
): ):
"""Celery task to index BookStack pages.""" """Celery task to index BookStack pages."""
import asyncio return run_async_celery_task(
lambda: _index_bookstack_pages(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_bookstack_pages( async def _index_bookstack_pages(
@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None, end_date: str | None,
): ):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio).""" """Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
import asyncio return run_async_celery_task(
lambda: _index_composio_connector(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date connector_id, search_space_id, user_id, start_date, end_date
) )
) )
finally:
loop.close()
async def _index_composio_connector( async def _index_composio_connector(

View file

@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex document_id: ID of document to reindex
user_id: ID of user who edited the document user_id: ID of user who edited the document
""" """
import asyncio return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reindex_document(document_id, user_id))
finally:
loop.close()
async def _reindex_document(document_id: int, user_id: str): async def _reindex_document(document_id: int, user_id: str):

View file

@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config from app.config import config
from app.services.notification_service import NotificationService from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import ( from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder, index_local_folder,
index_uploaded_files, index_uploaded_files,
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
) )
def delete_document_task(self, document_id: int): def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches.""" """Celery task to delete a document and its chunks in batches."""
loop = asyncio.new_event_loop() return run_async_celery_task(lambda: _delete_document_background(document_id))
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_document_background(document_id))
finally:
loop.close()
async def _delete_document_background(document_id: int) -> None: async def _delete_document_background(document_id: int) -> None:
@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None, folder_subtree_ids: list[int] | None = None,
): ):
"""Celery task to delete documents first, then the folder rows.""" """Celery task to delete documents first, then the folder rows."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
try:
loop.run_until_complete(
_delete_folder_documents(document_ids, folder_subtree_ids)
) )
finally:
loop.close()
async def _delete_folder_documents( async def _delete_folder_documents(
@ -209,12 +199,9 @@ async def _delete_folder_documents(
) )
def delete_search_space_task(self, search_space_id: int): def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches.""" """Celery task to delete a search space and heavy child rows in batches."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _delete_search_space_background(search_space_id)
try: )
loop.run_until_complete(_delete_search_space_background(search_space_id))
finally:
loop.close()
async def _delete_search_space_background(search_space_id: int) -> None: async def _delete_search_space_background(search_space_id: int) -> None:
@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space search_space_id: ID of the search space
user_id: ID of the user user_id: ID of the user
""" """
# Create a new event loop for this task return run_async_celery_task(
loop = asyncio.new_event_loop() lambda: _process_extension_document(
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_extension_document(
individual_document_dict, search_space_id, user_id individual_document_dict, search_space_id, user_id
) )
) )
finally:
loop.close()
async def _process_extension_document( async def _process_extension_document(
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space search_space_id: ID of the search space
user_id: ID of the user user_id: ID of the user
""" """
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _process_youtube_video(url, search_space_id, user_id)
)
try:
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
finally:
loop.close()
async def _process_youtube_video(url: str, search_space_id: int, user_id: str): async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e: except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}") logger.warning(f"[process_file_upload] Could not get file size: {e}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete( run_async_celery_task(
_process_file_upload(file_path, filename, search_space_id, user_id) lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
) )
logger.info( logger.info(
f"[process_file_upload] Task completed successfully for: {filename}" f"[process_file_upload] Task completed successfully for: {filename}"
@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}" f"Traceback:\n{traceback.format_exc()}"
) )
raise raise
finally:
loop.close()
async def _process_file_upload( async def _process_file_upload(
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start." "File may have been removed before syncing could start."
) )
# Mark document as failed since file is missing # Mark document as failed since file is missing
loop = asyncio.new_event_loop() run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _mark_document_failed(
try:
loop.run_until_complete(
_mark_document_failed(
document_id, document_id,
"File not found. Please re-upload the file.", "File not found. Please re-upload the file.",
) )
) )
finally:
loop.close()
return return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete( run_async_celery_task(
_process_file_with_document( lambda: _process_file_with_document(
document_id, document_id,
temp_path, temp_path,
filename, filename,
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}" f"Traceback:\n{traceback.format_exc()}"
) )
raise raise
finally:
loop.close()
async def _mark_document_failed(document_id: int, reason: str): async def _mark_document_failed(document_id: int, reason: str):
@ -1119,12 +1080,8 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support) connector_id: ID of the Circleback connector (for deletion support)
""" """
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _process_circleback_meeting(
try:
loop.run_until_complete(
_process_circleback_meeting(
meeting_id, meeting_id,
meeting_name, meeting_name,
markdown_content, markdown_content,
@ -1133,8 +1090,6 @@ def process_circleback_meeting_task(
connector_id, connector_id,
) )
) )
finally:
loop.close()
async def _process_circleback_meeting( async def _process_circleback_meeting(
@ -1291,12 +1246,8 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None, target_file_paths: list[str] | None = None,
): ):
"""Celery task to index a local folder. Config is passed directly — no connector row.""" """Celery task to index a local folder. Config is passed directly — no connector row."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _index_local_folder_async(
try:
loop.run_until_complete(
_index_local_folder_async(
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id, user_id=user_id,
folder_path=folder_path, folder_path=folder_path,
@ -1308,8 +1259,6 @@ def index_local_folder_task(
target_file_paths=target_file_paths, target_file_paths=target_file_paths,
) )
) )
finally:
loop.close()
async def _index_local_folder_async( async def _index_local_folder_async(
@ -1441,11 +1390,8 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic", processing_mode: str = "basic",
): ):
"""Celery task to index files uploaded from the desktop app.""" """Celery task to index files uploaded from the desktop app."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _index_uploaded_folder_files_async(
try:
loop.run_until_complete(
_index_uploaded_folder_files_async(
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id, user_id=user_id,
folder_name=folder_name, folder_name=folder_name,
@ -1456,8 +1402,6 @@ def index_uploaded_folder_files_task(
processing_mode=processing_mode, processing_mode=processing_mode,
) )
) )
finally:
loop.close()
async def _index_uploaded_folder_files_async( async def _index_uploaded_folder_files_async(
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) @celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str): def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space.""" """Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _ai_sort_search_space_async(search_space_id, user_id)
try: )
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
async def _ai_sort_search_space_async(search_space_id: int, user_id: str): async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
) )
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing.""" """Incremental AI sort for a single document after indexing."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
try:
loop.run_until_complete(
_ai_sort_document_async(search_space_id, user_id, document_id)
) )
finally:
loop.close()
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):

View file

@ -2,14 +2,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from app.celery_app import celery_app from app.celery_app import celery_app
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note from app.services.obsidian_plugin_indexer import upsert_note
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
user_id: str, user_id: str,
) -> None: ) -> None:
"""Process one Obsidian non-markdown attachment asynchronously.""" """Process one Obsidian non-markdown attachment asynchronously."""
loop = asyncio.new_event_loop() return run_async_celery_task(
asyncio.set_event_loop(loop) lambda: _index_obsidian_attachment(
try:
loop.run_until_complete(
_index_obsidian_attachment(
connector_id=connector_id, connector_id=connector_id,
payload_data=payload_data, payload_data=payload_data,
user_id=user_id, user_id=user_id,
) )
) )
finally:
loop.close()
async def _index_obsidian_attachment( async def _index_obsidian_attachment(

View file

@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
import sys import sys
from contextlib import asynccontextmanager
from sqlalchemy import select from sqlalchemy import select
@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config from app.config import config as app_config
from app.db import Podcast, PodcastStatus from app.db import Podcast, PodcastStatus
from app.services.billable_calls import ( from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError, QuotaInsufficientError,
_resolve_agent_billing_for_search_space, _resolve_agent_billing_for_search_space,
billable_call, billable_call,
) )
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,6 +36,13 @@ if sys.platform.startswith("win"):
# ============================================================================= # =============================================================================
@asynccontextmanager
async def _celery_billable_session():
"""Session factory used by billable_call inside the Celery worker loop."""
async with get_celery_session_maker()() as session:
yield session
@celery_app.task(name="generate_content_podcast", bind=True) @celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task( def generate_content_podcast_task(
self, self,
@ -46,27 +55,22 @@ def generate_content_podcast_task(
Celery task to generate podcast from source content. Celery task to generate podcast from source content.
Updates existing podcast record created by the tool. Updates existing podcast record created by the tool.
""" """
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
result = loop.run_until_complete( return run_async_celery_task(
_generate_content_podcast( lambda: _generate_content_podcast(
podcast_id, podcast_id,
source_content, source_content,
search_space_id, search_space_id,
user_prompt, user_prompt,
) )
) )
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e: except Exception as e:
logger.error(f"Error generating content podcast: {e!s}") logger.error(f"Error generating content podcast: {e!s}")
loop.run_until_complete(_mark_podcast_failed(podcast_id)) try:
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
except Exception:
logger.exception("Failed to mark podcast %s as failed", podcast_id)
return {"status": "failed", "podcast_id": podcast_id} return {"status": "failed", "podcast_id": podcast_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_podcast_failed(podcast_id: int) -> None: async def _mark_podcast_failed(podcast_id: int) -> None:
@ -148,11 +152,12 @@ async def _generate_content_podcast(
base_model=base_model, base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation", usage_type="podcast_generation",
thread_id=podcast.thread_id,
call_details={ call_details={
"podcast_id": podcast.id, "podcast_id": podcast.id,
"title": podcast.title, "title": podcast.title,
"thread_id": podcast.thread_id,
}, },
billable_session_factory=_celery_billable_session,
): ):
graph_result = await podcaster_graph.ainvoke( graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config initial_state, config=graph_config
@ -173,6 +178,18 @@ async def _generate_content_podcast(
"podcast_id": podcast.id, "podcast_id": podcast.id,
"reason": "premium_quota_exhausted", "reason": "premium_quota_exhausted",
} }
except BillingSettlementError:
logger.exception(
"Podcast %s: premium billing settlement failed",
podcast.id,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_settlement_failed",
}
podcast_transcript = graph_result.get("podcast_transcript", []) podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "") file_path = graph_result.get("final_podcast_file_path", "")
@ -194,7 +211,14 @@ async def _generate_content_podcast(
podcast.podcast_transcript = serializable_transcript podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path podcast.file_location = file_path
podcast.status = PodcastStatus.READY podcast.status = PodcastStatus.READY
logger.info(
"Podcast %s: committing READY transcript_entries=%d file=%s",
podcast.id,
len(serializable_transcript),
file_path,
)
await session.commit() await session.commit()
logger.info("Podcast %s: READY commit complete", podcast.id)
logger.info(f"Successfully generated podcast: {podcast.id}") logger.info(f"Successfully generated podcast: {podcast.id}")

View file

@ -7,7 +7,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app from app.celery_app import celery_app
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.utils.indexing_locks import is_connector_indexing_locked from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,15 +20,7 @@ def check_periodic_schedules_task():
This task runs every minute and triggers indexing for any connector This task runs every minute and triggers indexing for any connector
whose next_scheduled_at time has passed. whose next_scheduled_at time has passed.
""" """
import asyncio return run_async_celery_task(_check_and_trigger_schedules)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_check_and_trigger_schedules())
finally:
loop.close()
async def _check_and_trigger_schedules(): async def _check_and_trigger_schedules():

View file

@ -34,7 +34,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app from app.celery_app import celery_app
from app.config import config from app.config import config
from app.db import Document, DocumentStatus, Notification from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
Also marks associated pending/processing documents as failed. Also marks associated pending/processing documents as failed.
""" """
import asyncio
loop = asyncio.new_event_loop() async def _both() -> None:
asyncio.set_event_loop(loop) await _cleanup_stale_notifications()
await _cleanup_stale_document_processing_notifications()
try: return run_async_celery_task(_both)
loop.run_until_complete(_cleanup_stale_notifications())
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
finally:
loop.close()
async def _cleanup_stale_notifications(): async def _cleanup_stale_notifications():

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@ -18,7 +17,7 @@ from app.db import (
PremiumTokenPurchaseStatus, PremiumTokenPurchaseStatus,
) )
from app.routes import stripe_routes from app.routes import stripe_routes
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
@celery_app.task(name="reconcile_pending_stripe_page_purchases") @celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task(): def reconcile_pending_stripe_page_purchases_task():
"""Recover paid purchases that were left pending due to missed webhook handling.""" """Recover paid purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop() return run_async_celery_task(_reconcile_pending_page_purchases)
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_page_purchases())
finally:
loop.close()
async def _reconcile_pending_page_purchases() -> None: async def _reconcile_pending_page_purchases() -> None:
@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
@celery_app.task(name="reconcile_pending_stripe_token_purchases") @celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task(): def reconcile_pending_stripe_token_purchases_task():
"""Recover paid token purchases that were left pending due to missed webhook handling.""" """Recover paid token purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop() return run_async_celery_task(_reconcile_pending_token_purchases)
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_token_purchases())
finally:
loop.close()
async def _reconcile_pending_token_purchases() -> None: async def _reconcile_pending_token_purchases() -> None:

View file

@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
import sys import sys
from contextlib import asynccontextmanager
from sqlalchemy import select from sqlalchemy import select
@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import ( from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError, QuotaInsufficientError,
_resolve_agent_billing_for_search_space, _resolve_agent_billing_for_search_space,
billable_call, billable_call,
) )
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,6 +31,13 @@ if sys.platform.startswith("win"):
) )
@asynccontextmanager
async def _celery_billable_session():
"""Session factory used by billable_call inside the Celery worker loop."""
async with get_celery_session_maker()() as session:
yield session
@celery_app.task(name="generate_video_presentation", bind=True) @celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task( def generate_video_presentation_task(
self, self,
@ -41,27 +50,30 @@ def generate_video_presentation_task(
Celery task to generate video presentation from source content. Celery task to generate video presentation from source content.
Updates existing video presentation record created by the tool. Updates existing video presentation record created by the tool.
""" """
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try: try:
result = loop.run_until_complete( return run_async_celery_task(
_generate_video_presentation( lambda: _generate_video_presentation(
video_presentation_id, video_presentation_id,
source_content, source_content,
search_space_id, search_space_id,
user_prompt, user_prompt,
) )
) )
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e: except Exception as e:
logger.error(f"Error generating video presentation: {e!s}") logger.error(f"Error generating video presentation: {e!s}")
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) # Mark FAILED in a fresh loop — the previous loop is closed.
# Swallow secondary failures; the row will simply stay in
# GENERATING and be flushed by the periodic stale cleanup.
try:
run_async_celery_task(
lambda: _mark_video_presentation_failed(video_presentation_id)
)
except Exception:
logger.exception(
"Failed to mark video presentation %s as failed",
video_presentation_id,
)
return {"status": "failed", "video_presentation_id": video_presentation_id} return {"status": "failed", "video_presentation_id": video_presentation_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_video_presentation_failed(video_presentation_id: int) -> None: async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
@ -150,11 +162,12 @@ async def _generate_video_presentation(
base_model=base_model, base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation", usage_type="video_presentation_generation",
thread_id=video_pres.thread_id,
call_details={ call_details={
"video_presentation_id": video_pres.id, "video_presentation_id": video_pres.id,
"title": video_pres.title, "title": video_pres.title,
"thread_id": video_pres.thread_id,
}, },
billable_session_factory=_celery_billable_session,
): ):
graph_result = await video_presentation_graph.ainvoke( graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config initial_state, config=graph_config
@ -175,6 +188,18 @@ async def _generate_video_presentation(
"video_presentation_id": video_pres.id, "video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted", "reason": "premium_quota_exhausted",
} }
except BillingSettlementError:
logger.exception(
"VideoPresentation %s: premium billing settlement failed",
video_pres.id,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_settlement_failed",
}
# Serialize slides (parsed content + audio info merged) # Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", []) slides_raw = graph_result.get("slides", [])
@ -205,7 +230,14 @@ async def _generate_video_presentation(
video_pres.slides = serializable_slides video_pres.slides = serializable_slides
video_pres.scene_codes = serializable_scene_codes video_pres.scene_codes = serializable_scene_codes
video_pres.status = VideoPresentationStatus.READY video_pres.status = VideoPresentationStatus.READY
logger.info(
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
video_pres.id,
len(serializable_slides),
len(serializable_scene_codes),
)
await session.commit() await session.commit()
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
logger.info(f"Successfully generated video presentation: {video_pres.id}") logger.info(f"Successfully generated video presentation: {video_pres.id}")

View file

@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.errors import BusyError from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
@ -96,6 +97,47 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
return min(delay, TURN_CANCELLING_MAX_DELAY_MS) return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
if isinstance(candidate, dict):
value = candidate.get("value", candidate)
return value if isinstance(value, dict) else None
value = getattr(candidate, "value", None)
if isinstance(value, dict):
return value
if isinstance(candidate, (list, tuple)):
for item in candidate:
extracted = _extract_interrupt_value(item)
if extracted is not None:
return extracted
return None
for task in getattr(state, "tasks", ()) or ():
try:
interrupts = getattr(task, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
interrupts = ()
if not interrupts:
extracted = _extract_interrupt_value(task)
if extracted is not None:
return extracted
continue
for interrupt_item in interrupts:
extracted = _extract_interrupt_value(interrupt_item)
if extracted is not None:
return extracted
try:
state_interrupts = getattr(state, "interrupts", ()) or ()
except (AttributeError, IndexError, TypeError):
state_interrupts = ()
extracted = _extract_interrupt_value(state_interrupts)
if extracted is not None:
return extracted
return None
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
@ -518,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None:
) )
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
Used by the parallel preflight + agent-build path. The speculative build
closes over the request-scoped ``AsyncSession`` (for the brief connector
discovery / tool-factory window before its CPU work moves into a worker
thread). If preflight reports a 429 we want to fall back to the original
repin reload rebuild path, but we MUST NOT touch ``session`` again
until any in-flight session work owned by the speculative build has
fully settled :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
earlier in this PR (see ``connector_service`` parallel-gather revert).
We simply ``await`` the task and swallow any exception: in this path the
build's outcome is irrelevant — success populates the agent cache (a free
side effect), failure is discarded. The wasted CPU is acceptable since
429 fallbacks are rare and the original sequential code also paid the
full build cost on the same path.
"""
with contextlib.suppress(BaseException):
await task
def _classify_stream_exception( def _classify_stream_exception(
exc: Exception, exc: Exception,
*, *,
@ -655,6 +720,7 @@ async def _stream_agent_events(
fallback_commit_created_by_id: str | None = None, fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None, fallback_commit_thread_id: int | None = None,
runtime_context: Any = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent. """Shared async generator that streams and formats astream_events from the agent.
@ -760,7 +826,18 @@ async def _stream_agent_events(
return event return event
return None return None
async for event in agent.astream_events(input_data, config=config, version="v2"): # Per-invocation runtime context (Phase 1.5). When supplied,
# ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids``
# from ``runtime.context`` instead of its constructor closure — the
# prerequisite that lets the compiled-agent cache (Phase 1) reuse a
# single graph across turns. Astream_events_kwargs stays empty when
# callers leave ``runtime_context`` as ``None`` to preserve the
# legacy code path bit-for-bit.
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
if runtime_context is not None:
astream_kwargs["context"] = runtime_context
async for event in agent.astream_events(input_data, **astream_kwargs):
event_type = event.get("event", "") event_type = event.get("event", "")
if event_type == "on_chat_model_stream": if event_type == "on_chat_model_stream":
@ -1506,10 +1583,10 @@ async def _stream_agent_events(
if isinstance(tool_output, dict) if isinstance(tool_output, dict)
else "Podcast" else "Podcast"
) )
if podcast_status == "processing": if podcast_status in ("pending", "generating", "processing"):
completed_items = [ completed_items = [
f"Title: {podcast_title}", f"Title: {podcast_title}",
"Audio generation started", "Podcast generation started",
"Processing in background...", "Processing in background...",
] ]
elif podcast_status == "already_generating": elif podcast_status == "already_generating":
@ -1518,7 +1595,7 @@ async def _stream_agent_events(
"Podcast already in progress", "Podcast already in progress",
"Please wait for it to complete", "Please wait for it to complete",
] ]
elif podcast_status == "error": elif podcast_status in ("failed", "error"):
error_msg = ( error_msg = (
tool_output.get("error", "Unknown error") tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict) if isinstance(tool_output, dict)
@ -1528,6 +1605,11 @@ async def _stream_agent_events(
f"Title: {podcast_title}", f"Title: {podcast_title}",
f"Error: {error_msg[:50]}", f"Error: {error_msg[:50]}",
] ]
elif podcast_status in ("ready", "success"):
completed_items = [
f"Title: {podcast_title}",
"Podcast ready",
]
else: else:
completed_items = last_active_step_items completed_items = last_active_step_items
yield streaming_service.format_thinking_step( yield streaming_service.format_thinking_step(
@ -1710,20 +1792,28 @@ async def _stream_agent_events(
if isinstance(tool_output, dict) if isinstance(tool_output, dict)
else {"result": tool_output}, else {"result": tool_output},
) )
if ( if isinstance(tool_output, dict) and tool_output.get("status") in (
isinstance(tool_output, dict) "pending",
and tool_output.get("status") == "success" "generating",
"processing",
):
yield streaming_service.format_terminal_info(
f"Podcast queued: {tool_output.get('title', 'Podcast')}",
"success",
)
elif isinstance(tool_output, dict) and tool_output.get("status") in (
"ready",
"success",
): ):
yield streaming_service.format_terminal_info( yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success", "success",
) )
else: elif isinstance(tool_output, dict) and tool_output.get("status") in (
error_msg = ( "failed",
tool_output.get("error", "Unknown error") "error",
if isinstance(tool_output, dict) ):
else "Unknown error" error_msg = tool_output.get("error", "Unknown error")
)
yield streaming_service.format_terminal_info( yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}", f"Podcast generation failed: {error_msg}",
"error", "error",
@ -2165,10 +2255,10 @@ async def _stream_agent_events(
result.agent_called_update_memory = called_update_memory result.agent_called_update_memory = called_update_memory
_log_file_contract("turn_outcome", result) _log_file_contract("turn_outcome", result)
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks) interrupt_value = _first_interrupt_value(state)
if is_interrupted: if interrupt_value is not None:
result.is_interrupted = True result.is_interrupted = True
result.interrupt_value = state.tasks[0].interrupts[0].value result.interrupt_value = interrupt_value
yield streaming_service.format_interrupt_request(result.interrupt_value) yield streaming_service.format_interrupt_request(result.interrupt_value)
@ -2292,6 +2382,11 @@ async def stream_new_chat(
) )
_t0 = time.perf_counter() _t0 = time.perf_counter()
# Image-bearing turns force the Auto-pin resolver to filter the
# candidate pool to vision-capable cfgs (and force-repin a
# text-only existing pin). For explicit selections this flag is
# a no-op — the resolver returns the user's chosen id unchanged.
_requires_image_input = bool(user_image_data_urls)
try: try:
llm_config_id = ( llm_config_id = (
await resolve_or_get_pinned_llm_config_id( await resolve_or_get_pinned_llm_config_id(
@ -2300,13 +2395,29 @@ async def stream_new_chat(
search_space_id=search_space_id, search_space_id=search_space_id,
user_id=user_id, user_id=user_id,
selected_llm_config_id=llm_config_id, selected_llm_config_id=llm_config_id,
requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
except ValueError as pin_error: except ValueError as pin_error:
# Auto-pin's "no vision-capable cfg" path raises a ValueError
# whose message we map to the friendly image-input SSE error
# so the user sees the same message regardless of whether
# the gate fired in Auto-mode or in the agent_config check
# below.
error_code = (
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
if _requires_image_input and "vision-capable" in str(pin_error)
else "SERVER_ERROR"
)
error_kind = (
"user_error"
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
else "server_error"
)
yield _emit_stream_error( yield _emit_stream_error(
message=str(pin_error), message=str(pin_error),
error_kind="server_error", error_kind=error_kind,
error_code="SERVER_ERROR", error_code=error_code,
) )
yield streaming_service.format_done() yield streaming_service.format_done()
return return
@ -2326,6 +2437,50 @@ async def stream_new_chat(
llm_config_id, llm_config_id,
) )
# Capability safety net: a turn carrying user-uploaded images
# cannot be routed to a chat config that LiteLLM's authoritative
# model map *explicitly* marks as text-only (``supports_vision``
# set to False). The check is intentionally narrow — it only
# fires when LiteLLM is *certain* the model can't accept image
# input. Unknown / unmapped / vision-capable models pass
# through. Without this guard a known-text-only model would 404
# at the provider with ``"No endpoints found that support image
# input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
# failing here lets us return a friendly message that tells the
# user what to change.
if user_image_data_urls and agent_config is not None:
from app.services.provider_capabilities import (
is_known_text_only_chat_model,
)
agent_litellm_params = agent_config.litellm_params or {}
agent_base_model = (
agent_litellm_params.get("base_model")
if isinstance(agent_litellm_params, dict)
else None
)
if is_known_text_only_chat_model(
provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,
):
model_label = (
agent_config.config_name or agent_config.model_name or "model"
)
yield _emit_stream_error(
message=(
f"The selected model ({model_label}) does not support "
"image input. Switch to a vision-capable model "
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
"attachment and try again."
),
error_kind="user_error",
error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
)
yield streaming_service.format_done()
return
# Premium quota reservation for pinned premium model only. # Premium quota reservation for pinned premium model only.
_needs_premium_quota = ( _needs_premium_quota = (
agent_config is not None and user_id and agent_config.is_premium agent_config is not None and user_id and agent_config.is_premium
@ -2366,6 +2521,7 @@ async def stream_new_chat(
user_id=user_id, user_id=user_id,
selected_llm_config_id=0, selected_llm_config_id=0,
force_repin_free=True, force_repin_free=True,
requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
except ValueError as pin_error: except ValueError as pin_error:
@ -2440,23 +2596,102 @@ async def stream_new_chat(
# Detecting a 429 here lets us repin BEFORE the planner/classifier/ # Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the # title-generation LLM calls fan out and each independently hit the
# same upstream rate limit. # same upstream rate limit.
if ( #
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
# and is independent of the agent build (CPU-bound, ~5-7s). They used
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
# We now kick off preflight as a background task FIRST, then run the
# synchronous setup work and the agent build in parallel. In the
# success path (the common case) total wall time drops to roughly
# ``max(preflight, build)`` — the preflight finishes during the
# agent compile and we just consume its result. In the rare 429
# path the speculative build is awaited to completion (so its
# session usage is fully released) via
# :func:`_settle_speculative_agent_build`, then discarded, and
# we fall back to the original repin-and-rebuild flow.
preflight_needed = (
requested_llm_config_id == 0 requested_llm_config_id == 0
and llm_config_id < 0 and llm_config_id < 0
and not is_recently_healthy(llm_config_id) and not is_recently_healthy(llm_config_id)
): )
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter() _t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight:{llm_config_id}",
)
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
# Speculative agent build — runs in parallel with the preflight
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
# if preflight reports 429 we will discard this future and rebuild
# against the freshly pinned config below.
agent_build_task = asyncio.create_task(
create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_new_chat",
)
agent: Any = None
if preflight_task is not None:
try: try:
await _preflight_llm(llm) await preflight_task
mark_healthy(llm_config_id) mark_healthy(llm_config_id)
_perf_log.info( _perf_log.info(
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id, llm_config_id,
time.perf_counter() - _t_preflight, time.perf_counter() - _t_preflight,
) )
except Exception as preflight_exc: except Exception as preflight_exc:
# Both branches below need the session: the non-429 path
# may unwind via cleanup that uses ``session``, and the
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
# against it. Wait for the speculative build to release its
# session usage before we proceed.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc): if not _is_provider_rate_limited(preflight_exc):
raise raise
# 429: speculative agent is discarded; run the original
# repin → reload → rebuild path against the freshly
# pinned config.
previous_config_id = llm_config_id previous_config_id = llm_config_id
mark_runtime_cooldown( mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited" previous_config_id, reason="preflight_rate_limited"
@ -2470,6 +2705,7 @@ async def stream_new_chat(
user_id=user_id, user_id=user_id,
selected_llm_config_id=0, selected_llm_config_id=0,
exclude_config_ids={previous_config_id}, exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
except ValueError as pin_error: except ValueError as pin_error:
@ -2518,31 +2754,8 @@ async def stream_new_chat(
"fallback_config_id": llm_config_id, "fallback_config_id": llm_config_id,
}, },
) )
# Rebuild against the new llm/agent_config. Sequential
# Create connector service # here because we no longer have anything to overlap with.
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent( agent = await create_surfsense_deep_agent(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -2558,6 +2771,11 @@ async def stream_new_chat(
mentioned_document_ids=mentioned_document_ids, mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
) )
if agent is None:
# Either no preflight was needed, or preflight succeeded —
# in both cases the speculative build is the agent we want.
agent = await agent_build_task
_perf_log.info( _perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
) )
@ -2804,6 +3022,7 @@ async def stream_new_chat(
from litellm import acompletion from litellm import acompletion
from app.services.llm_router_service import LLMRouterService from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator from app.services.token_tracking_service import _turn_accumulator
_turn_accumulator.set(None) _turn_accumulator.set(None)
@ -2824,11 +3043,32 @@ async def stream_new_chat(
model="auto", messages=messages model="auto", messages=messages
) )
else: else:
# Apply the same ``api_base`` cascade chat / vision /
# image-gen call sites use so we never inherit
# ``litellm.api_base`` (commonly set by
# ``AZURE_OPENAI_ENDPOINT``) when the chat config
# itself ships an empty ``api_base``. Without this
# the title-gen on an OpenRouter chat config would
# 404 against the inherited Azure endpoint — see
# ``provider_api_base`` docstring for the same
# bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = (
raw_model.split("/", 1)[0] if "/" in raw_model else None
)
provider_value = (
agent_config.provider if agent_config is not None else None
)
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion( response = await acompletion(
model=llm.model, model=raw_model,
messages=messages, messages=messages,
api_key=getattr(llm, "api_key", None), api_key=getattr(llm, "api_key", None),
api_base=getattr(llm, "api_base", None), api_base=title_api_base,
) )
usage_info = None usage_info = None
@ -2862,6 +3102,18 @@ async def stream_new_chat(
title_emitted = False title_emitted = False
# Build the per-invocation runtime context (Phase 1.5).
# ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
# via ``runtime.context.mentioned_document_ids`` instead of its
# ``__init__`` closure — that way the same compiled-agent instance
# can serve multiple turns with different mention lists.
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
request_id=request_id,
turn_id=stream_result.turn_id,
)
_t_stream_start = time.perf_counter() _t_stream_start = time.perf_counter()
_first_event_logged = False _first_event_logged = False
runtime_rate_limit_recovered = False runtime_rate_limit_recovered = False
@ -2885,6 +3137,7 @@ async def stream_new_chat(
else FilesystemMode.CLOUD else FilesystemMode.CLOUD
), ),
fallback_commit_thread_id=chat_id, fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
): ):
if not _first_event_logged: if not _first_event_logged:
_perf_log.info( _perf_log.info(
@ -2953,6 +3206,7 @@ async def stream_new_chat(
user_id=user_id, user_id=user_id,
selected_llm_config_id=0, selected_llm_config_id=0,
exclude_config_ids={previous_config_id}, exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
@ -3499,21 +3753,75 @@ async def stream_resume_chat(
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets # one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first. # repinned without burning planner/classifier/title calls first.
if ( # See ``stream_new_chat`` for the full rationale on the speculative
# parallel build pattern below.
preflight_needed = (
requested_llm_config_id == 0 requested_llm_config_id == 0
and llm_config_id < 0 and llm_config_id < 0
and not is_recently_healthy(llm_config_id) and not is_recently_healthy(llm_config_id)
): )
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter() _t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight_resume:{llm_config_id}",
)
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent_build_task = asyncio.create_task(
create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_resume",
)
agent: Any = None
if preflight_task is not None:
try: try:
await _preflight_llm(llm) await preflight_task
mark_healthy(llm_config_id) mark_healthy(llm_config_id)
_perf_log.info( _perf_log.info(
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id, llm_config_id,
time.perf_counter() - _t_preflight, time.perf_counter() - _t_preflight,
) )
except Exception as preflight_exc: except Exception as preflight_exc:
# Same session-safety rationale as ``stream_new_chat``.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc): if not _is_provider_rate_limited(preflight_exc):
raise raise
previous_config_id = llm_config_id previous_config_id = llm_config_id
@ -3573,30 +3881,6 @@ async def stream_resume_chat(
"fallback_config_id": llm_config_id, "fallback_config_id": llm_config_id,
}, },
) )
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent( agent = await create_surfsense_deep_agent(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -3610,6 +3894,9 @@ async def stream_resume_chat(
thread_visibility=visibility, thread_visibility=visibility,
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
) )
if agent is None:
agent = await agent_build_task
_perf_log.info( _perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
) )
@ -3650,6 +3937,16 @@ async def stream_resume_chat(
) )
yield streaming_service.format_data("turn-status", {"status": "busy"}) yield streaming_service.format_data("turn-status", {"status": "busy"})
# Resume path doesn't carry new ``mentioned_document_ids`` —
# those are seeded in the original turn. We still pass a
# context so future middleware extensions (Phase 2) can rely on
# ``runtime.context`` always being populated.
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
request_id=request_id,
turn_id=stream_result.turn_id,
)
_t_stream_start = time.perf_counter() _t_stream_start = time.perf_counter()
_first_event_logged = False _first_event_logged = False
runtime_rate_limit_recovered = False runtime_rate_limit_recovered = False
@ -3670,6 +3967,7 @@ async def stream_resume_chat(
else FilesystemMode.CLOUD else FilesystemMode.CLOUD
), ),
fallback_commit_thread_id=chat_id, fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
): ):
if not _first_event_logged: if not _first_event_logged:
_perf_log.info( _perf_log.info(

View file

@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService, IndexingPipelineService,
PlaceholderInfo, PlaceholderInfo,
) )
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import ( from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from .base import ( from .base import (
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _format_calendar_event_to_markdown(event: dict) -> str:
return GoogleCalendarConnector.format_event_to_markdown(None, event)
def _build_connector_doc( def _build_connector_doc(
event: dict, event: dict,
event_markdown: str, event_markdown: str,
@ -150,7 +152,14 @@ async def index_google_calendar_events(
) )
return 0, 0, f"Connector with ID {connector_id} not found" return 0, 0, f"Connector with ID {connector_id} not found"
# ── Credential building ─────────────────────────────────────── is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
calendar_client = None
composio_service = None
connected_account_id = None
# ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id") connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id: if not connected_account_id:
@ -161,7 +170,7 @@ async def index_google_calendar_events(
{"error_type": "MissingComposioAccount"}, {"error_type": "MissingComposioAccount"},
) )
return 0, 0, "Composio connected_account_id not found" return 0, 0, "Composio connected_account_id not found"
credentials = build_composio_credentials(connected_account_id) composio_service = ComposioService()
else: else:
config_data = connector.config config_data = connector.config
@ -229,6 +238,7 @@ async def index_google_calendar_events(
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
if not is_composio_connector:
calendar_client = GoogleCalendarConnector( calendar_client = GoogleCalendarConnector(
credentials=credentials, credentials=credentials,
session=session, session=session,
@ -300,6 +310,23 @@ async def index_google_calendar_events(
) )
try: try:
if is_composio_connector:
start_dt = parse_date_flexible(start_date_str).replace(
hour=0, minute=0, second=0, microsecond=0
)
end_dt = parse_date_flexible(end_date_str).replace(
hour=23, minute=59, second=59, microsecond=0
)
events, error = await composio_service.get_calendar_events(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
time_min=start_dt.isoformat(),
time_max=end_dt.isoformat(),
max_results=250,
)
if not events and not error:
error = "No events found in the specified date range."
else:
events, error = await calendar_client.get_all_primary_calendar_events( events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str start_date=start_date_str, end_date=end_date_str
) )
@ -381,7 +408,7 @@ async def index_google_calendar_events(
documents_skipped += 1 documents_skipped += 1
continue continue
event_markdown = calendar_client.format_event_to_markdown(event) event_markdown = _format_calendar_event_to_markdown(event)
if not event_markdown.strip(): if not event_markdown.strip():
logger.warning(f"Skipping event with no content: {event_summary}") logger.warning(f"Skipping event with no content: {event_summary}")
documents_skipped += 1 documents_skipped += 1

View file

@ -9,6 +9,8 @@ import asyncio
import logging import logging
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
from sqlalchemy import String, cast, select from sqlalchemy import String, cast, select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService, IndexingPipelineService,
PlaceholderInfo, PlaceholderInfo,
) )
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
get_connector_by_id, get_connector_by_id,
update_connector_last_indexed, update_connector_last_indexed,
) )
from app.utils.google_credentials import ( from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
ACCEPTED_DRIVE_CONNECTOR_TYPES = { ACCEPTED_DRIVE_CONNECTOR_TYPES = {
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ComposioDriveClient:
"""Google Drive client facade backed by Composio tool execution.
Composio-managed OAuth connections can execute tools without exposing raw
OAuth tokens through connected account state.
"""
def __init__(
self,
session: AsyncSession,
connector_id: int,
connected_account_id: str,
entity_id: str,
):
self.session = session
self.connector_id = connector_id
self.connected_account_id = connected_account_id
self.entity_id = entity_id
self.composio = ComposioService()
async def list_files(
self,
query: str = "",
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
page_size: int = 100,
page_token: str | None = None,
) -> tuple[list[dict[str, Any]], str | None, str | None]:
params: dict[str, Any] = {
"page_size": min(page_size, 100),
"fields": fields,
}
if query:
params["q"] = query
if page_token:
params["page_token"] = page_token
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_LIST_FILES",
params=params,
entity_id=self.entity_id,
)
if not result.get("success"):
return [], None, result.get("error", "Unknown error")
data = result.get("data", {})
files = []
next_token = None
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
files = inner_data.get("files", [])
next_token = inner_data.get("nextPageToken") or inner_data.get(
"next_page_token"
)
elif isinstance(data, list):
files = data
return files, next_token, None
async def get_file_metadata(
self, file_id: str, fields: str = "*"
) -> tuple[dict[str, Any] | None, str | None]:
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
params={"file_id": file_id, "fields": fields},
entity_id=self.entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
data = result.get("data", {})
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
return inner_data, None
return None, "Could not extract metadata from Composio response"
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
return await self._download_file_content(file_id)
async def download_file_to_disk(
self,
file_id: str,
dest_path: str,
chunksize: int = 5 * 1024 * 1024,
) -> str | None:
del chunksize
content, error = await self.download_file(file_id)
if error:
return error
if content is None:
return "No content returned from Composio"
Path(dest_path).write_bytes(content)
return None
async def export_google_file(
self, file_id: str, mime_type: str
) -> tuple[bytes | None, str | None]:
return await self._download_file_content(file_id, mime_type=mime_type)
async def _download_file_content(
self, file_id: str, mime_type: str | None = None
) -> tuple[bytes | None, str | None]:
params: dict[str, Any] = {"file_id": file_id}
if mime_type:
params["mime_type"] = mime_type
result = await self.composio.execute_tool(
connected_account_id=self.connected_account_id,
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
params=params,
entity_id=self.entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
return self._read_download_result(result.get("data"))
def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
if isinstance(data, bytes):
return data, None
file_path: str | None = None
if isinstance(data, str):
file_path = data
elif isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
for key in ("file_path", "downloaded_file_content", "path", "uri"):
value = inner_data.get(key)
if isinstance(value, str):
file_path = value
break
if isinstance(value, dict):
nested = (
value.get("file_path")
or value.get("downloaded_file_content")
or value.get("path")
or value.get("uri")
or value.get("s3url")
)
if isinstance(nested, str):
file_path = nested
break
if not file_path:
return None, "No file path/content returned from Composio"
if file_path.startswith(("http://", "https://")):
try:
import urllib.request
with urllib.request.urlopen(file_path, timeout=60) as response:
return response.read(), None
except Exception as e:
return None, f"Failed to download Composio file URL: {e!s}"
path_obj = Path(file_path)
if path_obj.is_absolute() or ".composio" in str(path_obj):
if not path_obj.exists():
return None, f"File not found at path: {file_path}"
return path_obj.read_bytes(), None
try:
import base64
return base64.b64decode(file_path), None
except Exception:
return file_path.encode("utf-8"), None
def _build_drive_client_for_connector(
session: AsyncSession,
connector_id: int,
connector: object,
user_id: str,
) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
return None, (
f"Composio connected_account_id not found for connector {connector_id}"
)
return (
ComposioDriveClient(
session,
connector_id,
connected_account_id,
entity_id=f"surfsense_{user_id}",
),
None,
)
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
return None, "SECRET_KEY not configured but credentials are marked as encrypted"
return GoogleDriveClient(session, connector_id), None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -927,34 +1130,17 @@ async def index_google_drive_files(
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
pre_built_credentials = None drive_client, client_error = _build_drive_client_for_connector(
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: session, connector_id, connector, user_id
connected_account_id = connector.config.get("composio_connected_account_id") )
if not connected_account_id: if client_error or not drive_client:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure( await task_logger.log_task_failure(
log_entry, log_entry,
error_msg, client_error or "Failed to initialize Google Drive client",
"Missing Composio account", "Missing connector credentials",
{"error_type": "MissingComposioAccount"}, {"error_type": "ClientInitializationError"},
)
return 0, 0, error_msg, 0
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
0,
) )
return 0, 0, client_error, 0
connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@ -963,10 +1149,6 @@ async def index_google_drive_files(
from app.services.llm_service import get_vision_llm from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id) vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
if not folder_id: if not folder_id:
error_msg = "folder_id is required for Google Drive indexing" error_msg = "folder_id is required for Google Drive indexing"
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -979,8 +1161,14 @@ async def index_google_drive_files(
folder_tokens = connector.config.get("folder_tokens", {}) folder_tokens = connector.config.get("folder_tokens", {})
start_page_token = folder_tokens.get(target_folder_id) start_page_token = folder_tokens.get(target_folder_id)
is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
can_use_delta = ( can_use_delta = (
use_delta_sync and start_page_token and connector.last_indexed_at not is_composio_connector
and use_delta_sync
and start_page_token
and connector.last_indexed_at
) )
documents_unsupported = 0 documents_unsupported = 0
@ -1051,6 +1239,15 @@ async def index_google_drive_files(
) )
if documents_indexed > 0 or can_use_delta: if documents_indexed > 0 or can_use_delta:
if isinstance(drive_client, ComposioDriveClient):
(
new_token,
token_error,
) = await drive_client.composio.get_drive_start_page_token(
drive_client.connected_account_id,
drive_client.entity_id,
)
else:
new_token, token_error = await get_start_page_token(drive_client) new_token, token_error = await get_start_page_token(drive_client)
if new_token and not token_error: if new_token and not token_error:
await session.refresh(connector) await session.refresh(connector)
@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
) )
return 0, error_msg return 0, error_msg
pre_built_credentials = None drive_client, client_error = _build_drive_client_for_connector(
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: session, connector_id, connector, user_id
connected_account_id = connector.config.get("composio_connected_account_id") )
if not connected_account_id: if client_error or not drive_client:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure( await task_logger.log_task_failure(
log_entry, log_entry,
error_msg, client_error or "Failed to initialize Google Drive client",
"Missing Composio account", "Missing connector credentials",
{"error_type": "MissingComposioAccount"}, {"error_type": "ClientInitializationError"},
)
return 0, error_msg
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
) )
return 0, client_error
connector_enable_summary = getattr(connector, "enable_summary", True) connector_enable_summary = getattr(connector, "enable_summary", True)
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False) connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
from app.services.llm_service import get_vision_llm from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id) vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
file, error = await get_file_by_id(drive_client, file_id) file, error = await get_file_by_id(drive_client, file_id)
if error or not file: if error or not file:
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}" error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
@ -1276,30 +1454,16 @@ async def index_google_drive_selected_files(
) )
return 0, 0, [error_msg] return 0, 0, [error_msg]
pre_built_credentials = None drive_client, client_error = _build_drive_client_for_connector(
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: session, connector_id, connector, user_id
connected_account_id = connector.config.get("composio_connected_account_id") )
if not connected_account_id: if client_error or not drive_client:
error_msg = f"Composio connected_account_id not found for connector {connector_id}" error_msg = client_error or "Failed to initialize Google Drive client"
await task_logger.log_task_failure( await task_logger.log_task_failure(
log_entry, log_entry,
error_msg, error_msg,
"Missing Composio account", "Missing connector credentials",
{"error_type": "MissingComposioAccount"}, {"error_type": "ClientInitializationError"},
)
return 0, 0, [error_msg]
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
error_msg = (
"SECRET_KEY not configured but credentials are marked as encrypted"
)
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
) )
return 0, 0, [error_msg] return 0, 0, [error_msg]
@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
from app.services.llm_service import get_vision_llm from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id) vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
indexed, skipped, unsupported, errors = await _index_selected_files( indexed, skipped, unsupported, errors = await _index_selected_files(
drive_client, drive_client,
session, session,

View file

@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService, IndexingPipelineService,
PlaceholderInfo, PlaceholderInfo,
) )
from app.services.composio_service import ComposioService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import ( from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from .base import ( from .base import (
calculate_date_range, calculate_date_range,
@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _normalize_composio_gmail_message(message: dict) -> dict:
if message.get("payload"):
return message
headers = []
header_values = {
"Subject": message.get("subject"),
"From": message.get("from") or message.get("sender"),
"To": message.get("to") or message.get("recipient"),
"Date": message.get("date"),
}
for name, value in header_values.items():
if value:
headers.append({"name": name, "value": value})
return {
**message,
"id": message.get("id")
or message.get("message_id")
or message.get("messageId"),
"threadId": message.get("threadId") or message.get("thread_id"),
"payload": {"headers": headers},
"snippet": message.get("snippet", ""),
"messageText": message.get("messageText") or message.get("body") or "",
}
def _format_gmail_message_to_markdown(message: dict) -> str:
headers = {
header.get("name", "").lower(): header.get("value", "")
for header in message.get("payload", {}).get("headers", [])
if isinstance(header, dict)
}
subject = headers.get("subject", "No Subject")
from_email = headers.get("from", "Unknown Sender")
to_email = headers.get("to", "Unknown Recipient")
date_str = headers.get("date", "Unknown Date")
message_text = (
message.get("messageText")
or message.get("body")
or message.get("text")
or message.get("snippet", "")
)
return (
f"# {subject}\n\n"
f"**From:** {from_email}\n"
f"**To:** {to_email}\n"
f"**Date:** {date_str}\n\n"
f"## Message Content\n\n{message_text}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {message.get('id', 'Unknown')}\n"
f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
)
def _build_connector_doc( def _build_connector_doc(
message: dict, message: dict,
markdown_content: str, markdown_content: str,
@ -162,7 +216,14 @@ async def index_google_gmail_messages(
) )
return 0, 0, error_msg return 0, 0, error_msg
# ── Credential building ─────────────────────────────────────── is_composio_connector = (
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
)
gmail_connector = None
composio_service = None
connected_account_id = None
# ── Credential/client building ────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id") connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id: if not connected_account_id:
@ -173,7 +234,7 @@ async def index_google_gmail_messages(
{"error_type": "MissingComposioAccount"}, {"error_type": "MissingComposioAccount"},
) )
return 0, 0, "Composio connected_account_id not found" return 0, 0, "Composio connected_account_id not found"
credentials = build_composio_credentials(connected_account_id) composio_service = ComposioService()
else: else:
config_data = connector.config config_data = connector.config
@ -241,6 +302,7 @@ async def index_google_gmail_messages(
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
if not is_composio_connector:
gmail_connector = GoogleGmailConnector( gmail_connector = GoogleGmailConnector(
credentials, session, user_id, connector_id credentials, session, user_id, connector_id
) )
@ -254,6 +316,55 @@ async def index_google_gmail_messages(
f"Fetching emails for connector {connector_id} " f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}" f"from {calculated_start_date} to {calculated_end_date}"
) )
if is_composio_connector:
query_parts = []
if calculated_start_date:
query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
if calculated_end_date:
query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
query = " ".join(query_parts)
messages = []
page_token = None
error = None
while len(messages) < max_messages:
page_size = min(50, max_messages - len(messages))
(
page_messages,
page_token,
_estimate,
page_error,
) = await composio_service.get_gmail_messages(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=page_size,
page_token=page_token,
)
if page_error:
error = page_error
break
for page_message in page_messages:
message_id = (
page_message.get("id")
or page_message.get("message_id")
or page_message.get("messageId")
)
if message_id:
(
detail,
detail_error,
) = await composio_service.get_gmail_message_detail(
connected_account_id=connected_account_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if not detail_error and isinstance(detail, dict):
page_message = detail
messages.append(_normalize_composio_gmail_message(page_message))
if not page_token:
break
else:
messages, error = await gmail_connector.get_recent_messages( messages, error = await gmail_connector.get_recent_messages(
max_results=max_messages, max_results=max_messages,
start_date=calculated_start_date, start_date=calculated_start_date,
@ -326,7 +437,12 @@ async def index_google_gmail_messages(
documents_skipped += 1 documents_skipped += 1
continue continue
markdown_content = gmail_connector.format_message_to_markdown(message) if is_composio_connector:
markdown_content = _format_gmail_message_to_markdown(message)
else:
markdown_content = gmail_connector.format_message_to_markdown(
message
)
if not markdown_content.strip(): if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {message_id}") logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1 documents_skipped += 1

View file

@ -1,6 +1,6 @@
[project] [project]
name = "surf-new-backend" name = "surf-new-backend"
version = "0.0.19" version = "0.0.20"
description = "SurfSense Backend" description = "SurfSense Backend"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
@ -71,11 +71,11 @@ dependencies = [
"langchain>=1.2.13", "langchain>=1.2.13",
"langgraph>=1.1.3", "langgraph>=1.1.3",
"langchain-community>=0.4.1", "langchain-community>=0.4.1",
"deepagents>=0.4.12",
"stripe>=15.0.0", "stripe>=15.0.0",
"azure-ai-documentintelligence>=1.0.2", "azure-ai-documentintelligence>=1.0.2",
"litellm>=1.83.7", "litellm>=1.83.7",
"langchain-litellm>=0.6.4", "langchain-litellm>=0.6.4",
"deepagents>=0.4.12,<0.5",
] ]
[dependency-groups] [dependency-groups]

View file

@ -0,0 +1,558 @@
"""End-to-end smoke test for vision / image config wiring.
Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
exercises every chat / vision / image-generation config + the OpenRouter
dynamic catalog. For each config the script:
1. Reports the resolver classification (catalog-allow vs strict-block).
2. Optionally fires a tiny live API call against the provider:
- Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
``"reply with one word: ok"``.
- Vision configs: same, against the dedicated vision router pool.
- Image-gen configs: ``litellm.aimage_generation`` with a single tiny
prompt and ``n=1``.
- OpenRouter integration: samples one chat, one vision, one image-gen
model from the dynamically fetched catalog.
Usage::
python -m scripts.verify_chat_image_capability # capability + connectivity
python -m scripts.verify_chat_image_capability --no-live # capability resolver only
The script is meant to be runnable from the repository root or from
``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
end so it's usable as a CI smoke check too.
Live-mode caveat: each successful call costs a small amount of provider
credit (a few tokens or one tiny generated image per config). The
default size for image generation is ``1024x1024`` because Azure
GPT-image deployments reject smaller sizes; OpenRouter image-gen models
generally accept the same size.
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Any
# Bootstrap the surfsense_backend package on sys.path so the script runs
# from the repo root or from `surfsense_backend/` interchangeably.
_HERE = os.path.dirname(os.path.abspath(__file__))
_BACKEND_ROOT = os.path.dirname(_HERE)
if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
import litellm # noqa: E402
from app.config import config # noqa: E402
from app.services.openrouter_integration_service import ( # noqa: E402
_OPENROUTER_DYNAMIC_MARKER,
OpenRouterIntegrationService,
)
from app.services.provider_api_base import resolve_api_base # noqa: E402
from app.services.provider_capabilities import ( # noqa: E402
derive_supports_image_input,
is_known_text_only_chat_model,
)
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
# Quiet down LiteLLM's verbose router/cost logs so the script output is
# scannable.
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
logging.getLogger("litellm").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
# 1x1 transparent PNG — used as the cheapest possible vision payload.
_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
# ---------------------------------------------------------------------------
# Result accounting
# ---------------------------------------------------------------------------
@dataclass
class ProbeResult:
label: str
surface: str
config_id: int | str
capability_ok: bool | None = None
capability_note: str = ""
live_ok: bool | None = None
live_note: str = ""
duration_s: float = 0.0
@dataclass
class Report:
results: list[ProbeResult] = field(default_factory=list)
def add(self, r: ProbeResult) -> None:
self.results.append(r)
def render(self) -> int:
passed = failed = skipped = 0
print()
print("=" * 92)
print(
f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
)
print("-" * 92)
for r in self.results:
def _flag(value: bool | None) -> str:
if value is None:
return "skip"
return "ok" if value else "fail"
cap = _flag(r.capability_ok)
live = _flag(r.live_ok)
if r.capability_ok is False or r.live_ok is False:
failed += 1
elif r.capability_ok is None and r.live_ok is None:
skipped += 1
else:
passed += 1
print(
f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} "
f"{r.duration_s:>5.2f}s {r.label}"
)
if r.capability_note:
print(f" cap: {r.capability_note}")
if r.live_note:
print(f" live: {r.live_note}")
print("-" * 92)
print(
f"Total: {passed} ok / {failed} fail / {skipped} skip "
f"(of {len(self.results)} probes)"
)
print("=" * 92)
return failed
# ---------------------------------------------------------------------------
# Capability probes (no network)
# ---------------------------------------------------------------------------
def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
"""For chat configs the catalog flag is *expected* True (vision-capable
pool). The probe reports both the resolver value and the strict
safety-net value to surface any drift between them."""
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
cap = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
block = is_known_text_only_chat_model(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
note = f"derive={cap} strict_block={block}"
if not cap and not block:
# Resolver said False but strict gate is also False — that means
# OR modalities published [text] explicitly. Surface it.
note += " (OR modality says text-only)"
# We accept a True derive *or* (False derive AND False block) as
# 'capability ok' — either way, the streaming task will flow through.
ok = cap or not block
return ok, note
def _build_chat_model_string(cfg: dict) -> str:
if cfg.get("custom_provider"):
return f"{cfg['custom_provider']}/{cfg['model_name']}"
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
return f"{prefix}/{cfg['model_name']}"
# ---------------------------------------------------------------------------
# Live probes (network calls)
# ---------------------------------------------------------------------------
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
model_string = _build_chat_model_string(cfg)
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=model_string.split("/", 1)[0],
config_api_base=cfg.get("api_base") or None,
)
kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "reply with one word: ok"},
{
"type": "image_url",
"image_url": {"url": _TINY_PNG_DATA_URL},
},
],
}
],
"max_tokens": 16,
"timeout": 60,
}
if api_base:
kwargs["api_base"] = api_base
if cfg.get("litellm_params"):
# Strip pricing keys — they're tracking-only and confuse some
# provider validators (e.g. azure/openai reject unknown kwargs
# in strict mode).
merged = {
k: v
for k, v in dict(cfg["litellm_params"]).items()
if k
not in {
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_pixel",
"output_cost_per_pixel",
}
}
kwargs.update(merged)
try:
resp = await litellm.acompletion(**kwargs)
except Exception as exc:
return False, f"{type(exc).__name__}: {exc}"
text = resp.choices[0].message.content if resp.choices else ""
return True, f"got reply ({(text or '').strip()[:40]!r})"
# Gemini image models occasionally return zero-length ``data`` for the
# minimal "red dot on white" prompt (provider-side safety / empty-output
# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
# the request itself succeeds). Use a more naturalistic prompt and
# retry once with a different one before giving up.
_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
"A simple icon of a coffee cup, flat illustration",
"A small green leaf on a white background",
)
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"""Generate one tiny image to verify the deployment is reachable."""
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
if cfg.get("custom_provider"):
prefix = cfg["custom_provider"]
else:
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
model_string = f"{prefix}/{cfg['model_name']}"
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=prefix,
config_api_base=cfg.get("api_base") or None,
)
base_kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
"n": 1,
"size": "1024x1024",
"timeout": 120,
}
if api_base:
base_kwargs["api_base"] = api_base
if cfg.get("api_version"):
base_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
base_kwargs.update(
{
k: v
for k, v in dict(cfg["litellm_params"]).items()
if k
not in {
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_pixel",
"output_cost_per_pixel",
}
}
)
last_note = ""
for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1):
try:
resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs)
except Exception as exc:
last_note = f"{type(exc).__name__}: {exc}"
continue
data_count = len(getattr(resp, "data", None) or [])
if data_count > 0:
return True, (
f"received {data_count} image(s) on attempt {attempt} "
f"(prompt={prompt!r})"
)
last_note = (
f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
)
return False, last_note
# ---------------------------------------------------------------------------
# Probe drivers
# ---------------------------------------------------------------------------
def _is_or_dynamic(cfg: dict) -> bool:
return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER))
async def probe_chat_configs(report: Report, *, live: bool) -> None:
print("\n[chat configs from global_llm_configs (YAML-static)]")
for cfg in config.GLOBAL_LLM_CONFIGS:
# Skip OR dynamic entries here — handled in the OR section so
# the YAML / OR split stays clear in the report.
if _is_or_dynamic(cfg):
continue
result = ProbeResult(
label=str(cfg.get("name") or cfg.get("model_name")),
surface="chat-yaml",
config_id=cfg.get("id"),
)
cap_ok, cap_note = _probe_chat_capability(cfg)
result.capability_ok = cap_ok
result.capability_note = cap_note
if live:
t0 = time.perf_counter()
ok, note = await _live_chat_image_call(cfg)
result.live_ok = ok
result.live_note = note
result.duration_s = time.perf_counter() - t0
report.add(result)
async def probe_vision_configs(report: Report, *, live: bool) -> None:
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
if _is_or_dynamic(cfg):
continue
result = ProbeResult(
label=str(cfg.get("name") or cfg.get("model_name")),
surface="vision",
config_id=cfg.get("id"),
)
# For vision configs, capability is implied — they're in the
# dedicated vision pool. Run the same resolver to flag any
# surprise disagreement.
cap_ok, cap_note = _probe_chat_capability(cfg)
result.capability_ok = cap_ok
result.capability_note = cap_note
if live:
t0 = time.perf_counter()
ok, note = await _live_chat_image_call(cfg)
result.live_ok = ok
result.live_note = note
result.duration_s = time.perf_counter() - t0
report.add(result)
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
print(
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
)
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
if _is_or_dynamic(cfg):
continue
result = ProbeResult(
label=str(cfg.get("name") or cfg.get("model_name")),
surface="image-gen",
config_id=cfg.get("id"),
)
# Image gen configs don't have a "supports_image_input" flag;
# the catalog tracks output, not input. Mark capability as None
# (skip) for the report.
if live:
t0 = time.perf_counter()
ok, note = await _live_image_gen_call(cfg)
result.live_ok = ok
result.live_note = note
result.duration_s = time.perf_counter() - t0
report.add(result)
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
"""Sample one chat (vision-capable), one vision, one image-gen model
from the live OpenRouter catalogue. Doesn't iterate the full pool
(would be hundreds of probes); just validates the integration end-
to-end on a representative model from each surface."""
print("\n[OpenRouter integration: sampled probes]")
settings = config.OPENROUTER_INTEGRATION_SETTINGS
if not settings:
report.add(
ProbeResult(
label="OpenRouter integration",
surface="openrouter",
config_id="settings",
capability_ok=None,
capability_note="openrouter_integration disabled in YAML — skipping",
live_ok=None,
)
)
return
service = OpenRouterIntegrationService.get_instance()
or_chat = [
c
for c in config.GLOBAL_LLM_CONFIGS
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
]
or_vision = [
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
]
or_image_gen = [
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
]
# Pick one representative per provider family per surface so a single
# broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
# surfaces independently of the others. Each needle matches the
# OpenRouter ``model_name`` prefix; the first match wins.
def _pick_first(pool: list[dict], needle: str) -> dict | None:
for c in pool:
if (c.get("model_name") or "").lower().startswith(needle):
return c
return None
chat_picks = [
("or-chat", _pick_first(or_chat, "openai/gpt-4o")),
("or-chat", _pick_first(or_chat, "anthropic/claude")),
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
]
vision_picks = [
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
("or-vision", _pick_first(or_vision, "anthropic/claude")),
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
]
image_picks = [
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
# / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
# the actual prefix.
("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")),
]
print(
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
)
for surface, picked in chat_picks + vision_picks + image_picks:
if not picked:
report.add(
ProbeResult(
label=f"<no candidate for {surface}>",
surface=surface,
config_id="-",
capability_ok=None,
capability_note="no candidate found in OR catalog",
)
)
continue
runner = (
_live_image_gen_call if surface == "or-image" else _live_chat_image_call
)
result = ProbeResult(
label=str(picked.get("model_name")),
surface=surface,
config_id=picked.get("id"),
)
if surface != "or-image":
cap_ok, cap_note = _probe_chat_capability(picked)
result.capability_ok = cap_ok
result.capability_note = cap_note
if live:
t0 = time.perf_counter()
ok, note = await runner(picked)
result.live_ok = ok
result.live_note = note
result.duration_s = time.perf_counter() - t0
report.add(result)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
async def main(args: argparse.Namespace) -> int:
print("Loaded global configs:")
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
# Initialize the OpenRouter integration so the catalog is populated
# (this is what main.py does at startup). It's idempotent.
if config.OPENROUTER_INTEGRATION_SETTINGS:
try:
from app.config import initialize_openrouter_integration
initialize_openrouter_integration()
except Exception as exc:
print(f" WARNING: OpenRouter integration init failed: {exc}")
print(
f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
)
report = Report()
if not args.skip_chat:
await probe_chat_configs(report, live=args.live)
if not args.skip_vision:
await probe_vision_configs(report, live=args.live)
if not args.skip_image_gen:
await probe_image_gen_configs(report, live=args.live)
if not args.skip_openrouter:
await probe_openrouter_catalog(report, live=args.live)
failed = report.render()
return 1 if failed else 0
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--no-live",
dest="live",
action="store_false",
help="Skip live API calls — capability resolver only.",
)
parser.set_defaults(live=True)
parser.add_argument("--skip-chat", action="store_true")
parser.add_argument("--skip-vision", action="store_true")
parser.add_argument("--skip-image-gen", action="store_true")
parser.add_argument("--skip-openrouter", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = _parse_args()
sys.exit(asyncio.run(main(args)))

View file

@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
build-failure non-caching) and the cache-key signature helpers that
``create_surfsense_deep_agent`` relies on. The integration with
``create_surfsense_deep_agent`` is covered separately by the streaming
contract tests; this module focuses on the primitives so a regression
in the cache implementation is caught before it reaches the agent
factory.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
from app.agents.new_chat.agent_cache import (
flags_signature,
reload_for_tests,
stable_hash,
system_prompt_hash,
tools_signature,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------
def test_stable_hash_is_deterministic_across_calls() -> None:
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
assert a == b
def test_stable_hash_changes_when_any_part_changes() -> None:
base = stable_hash("v1", 42, "thread-9")
assert stable_hash("v1", 42, "thread-10") != base
assert stable_hash("v2", 42, "thread-9") != base
assert stable_hash("v1", 43, "thread-9") != base
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
"""Two tool lists with the same surface must hash identically.
The cache key MUST NOT change when the underlying ``BaseTool``
instances are different Python objects (a fresh request constructs
fresh tool instances every time). Hashing on ``(name, description)``
keeps the cache hot across requests with identical tool surfaces.
"""
@dataclass
class FakeTool:
name: str
description: str
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
sig_a = tools_signature(
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
)
sig_b = tools_signature(
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_a == sig_b, "tool order must not affect the signature"
# Adding a tool rotates the key.
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
sig_c = tools_signature(
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_c != sig_a
def test_tools_signature_rotates_when_connector_set_changes() -> None:
@dataclass
class FakeTool:
name: str
description: str
tools = [FakeTool("a", "x")]
base = tools_signature(
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
)
added = tools_signature(
tools,
available_connectors=["NOTION", "SLACK"],
available_document_types=["FILE"],
)
assert base != added, "adding a connector must rotate the cache key"
def test_flags_signature_changes_when_flag_flips() -> None:
@dataclass(frozen=True)
class Flags:
a: bool = True
b: bool = False
base = flags_signature(Flags())
flipped = flags_signature(Flags(b=True))
assert base != flipped
def test_system_prompt_hash_is_stable_and_distinct() -> None:
p1 = "You are a helpful assistant."
p2 = "You are a helpful assistant!" # one-character delta
assert system_prompt_hash(p1) == system_prompt_hash(p1)
assert system_prompt_hash(p1) != system_prompt_hash(p2)
# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is b, "cache must return the SAME object across hits"
assert builds == 1, "builder must run exactly once"
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k1", builder=builder)
b = await cache.get_or_build("k2", builder=builder)
assert a is not b
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
# ttl=0 means every read sees the entry as immediately stale.
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "stale entry must rebuild a fresh instance"
assert builds == 2
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("a", builder=builder)
_ = await cache.get_or_build("b", builder=builder)
# Re-touch "a" so "b" is now the LRU victim.
a_again = await cache.get_or_build("a", builder=builder)
assert a_again is a
# Inserting "c" should evict "b" (LRU), not "a".
_ = await cache.get_or_build("c", builder=builder)
assert cache.stats()["size"] == 2
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
a_hit = await cache.get_or_build("a", builder=builder)
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
"""Two concurrent get_or_build calls on the same key must share one builder."""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
build_started = asyncio.Event()
builds = 0
async def slow_builder() -> object:
nonlocal builds
builds += 1
build_started.set()
# Yield control so the second waiter can race against us.
await asyncio.sleep(0.05)
return object()
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
# Wait until the first builder has started, then race a second waiter.
await build_started.wait()
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
a, b = await asyncio.gather(task_a, task_b)
assert a is b, "coalesced waiters must observe the same value"
assert builds == 1, "concurrent cold misses must collapse to ONE build"
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
"""A builder that raises must NOT poison the cache.
The next caller for the same key must run the builder again (not
re-raise the cached exception).
"""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
attempts = 0
async def flaky_builder() -> object:
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("transient")
return object()
with pytest.raises(RuntimeError, match="transient"):
await cache.get_or_build("k", builder=flaky_builder)
# Second call must retry — not re-raise the cached exception.
value = await cache.get_or_build("k", builder=flaky_builder)
assert value is not None
assert attempts == 2
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k", builder=builder)
assert cache.invalidate("k") is True
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "post-invalidation lookup must rebuild"
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
async def builder() -> object:
return object()
await cache.get_or_build("user:1:thread:1", builder=builder)
await cache.get_or_build("user:1:thread:2", builder=builder)
await cache.get_or_build("user:2:thread:1", builder=builder)
removed = cache.invalidate_prefix("user:1:")
assert removed == 2
assert cache.stats()["size"] == 1
# The user:2 entry must still be hot (no rebuild).
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
assert survivor_value is not None

View file

@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE", "SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL", "SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]: ]:
monkeypatch.delenv(name, raising=False) monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch) _clear_all(monkeypatch)
flags = reload_for_tests() flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags) assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False assert flags.disable_new_agent_stack is False
assert flags.any_new_middleware_enabled() is False assert flags.enable_context_editing is True
assert flags.enable_compaction_v2 is True
assert flags.enable_retry_after is True
assert flags.enable_model_fallback is False
assert flags.enable_model_call_limit is True
assert flags.enable_tool_call_limit is True
assert flags.enable_tool_call_repair is True
assert flags.enable_doom_loop is True
assert flags.enable_permission is True
assert flags.enable_busy_mutex is True
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
# ``db_session`` refactor landed). The companion gp-subagent share
# flag stays default-off pending data on cold-miss frequency.
assert flags.enable_agent_cache is True
assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags( def test_master_kill_switch_overrides_individual_flags(
@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL", "enable_otel": "SURFSENSE_ENABLE_OTEL",
} }
# `enable_otel` is intentionally orthogonal — it does NOT count toward
# ``any_new_middleware_enabled`` because OTel is observability-only and
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
for attr, env_name in flag_to_env.items(): for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch) _clear_all(monkeypatch)
monkeypatch.setenv(env_name, "true") monkeypatch.setenv(env_name, "false")
flags = reload_for_tests() flags = reload_for_tests()
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
if attr in counts_toward_middleware:
assert flags.any_new_middleware_enabled() is True
else:
assert flags.any_new_middleware_enabled() is False

View file

@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.
The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
the system message and the OpenRouterAnthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
_flatten_text_blocks,
_flattened_request,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------
class TestFlattenTextBlocks:
def test_joins_text_blocks_with_double_newline(self) -> None:
blocks = [
{"type": "text", "text": "<surfsense base>"},
{"type": "text", "text": "<filesystem section>"},
{"type": "text", "text": "<skills section>"},
]
assert (
_flatten_text_blocks(blocks)
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
)
def test_handles_single_text_block(self) -> None:
blocks = [{"type": "text", "text": "only one"}]
assert _flatten_text_blocks(blocks) == "only one"
def test_handles_empty_list(self) -> None:
assert _flatten_text_blocks([]) == ""
def test_passes_through_bare_string_blocks(self) -> None:
# LangChain content can mix bare strings and dict blocks.
blocks = ["raw string", {"type": "text", "text": "dict block"}]
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
def test_returns_none_for_image_block(self) -> None:
# System messages with images are rare — but we never want to
# silently lose the image payload by joining as text.
blocks = [
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_for_non_dict_non_str_block(self) -> None:
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_field_missing(self) -> None:
blocks = [{"type": "text"}] # no ``text`` key
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_is_not_string(self) -> None:
blocks = [{"type": "text", "text": ["nested", "list"]}]
assert _flatten_text_blocks(blocks) is None
def test_drops_cache_control_from_inner_blocks(self) -> None:
# The whole point: existing cache_control on inner blocks is
# discarded so LiteLLM's ``cache_control_injection_points`` can
# re-attach exactly one breakpoint after flattening.
blocks = [
{"type": "text", "text": "first"},
{
"type": "text",
"text": "second",
"cache_control": {"type": "ephemeral"},
},
]
flattened = _flatten_text_blocks(blocks)
assert flattened == "first\n\nsecond"
assert "cache_control" not in flattened # type: ignore[operator]
# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------
def _make_request(system_message: SystemMessage | None) -> Any:
"""Build a minimal ModelRequest stub. We only need .system_message
and .override(system_message=...) the middleware never touches
other fields.
"""
request = MagicMock()
request.system_message = system_message
def override(**kwargs: Any) -> Any:
new_request = MagicMock()
new_request.system_message = kwargs.get(
"system_message", request.system_message
)
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
return new_request
request.override = override
return request
class TestFlattenedRequest:
def test_collapses_multi_block_system_to_string(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<base>"},
{"type": "text", "text": "<todo>"},
{"type": "text", "text": "<filesystem>"},
{"type": "text", "text": "<skills>"},
{"type": "text", "text": "<subagents>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message, SystemMessage)
assert flattened.system_message.content == (
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
)
def test_no_op_for_string_content(self) -> None:
sys = SystemMessage(content="already a string")
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_for_single_block_list(self) -> None:
# One block already produces one breakpoint — no need to flatten.
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_when_system_message_missing(self) -> None:
request = _make_request(None)
assert _flattened_request(request) is None
def test_no_op_when_list_contains_non_text_block(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]
)
request = _make_request(sys)
assert _flattened_request(request) is None
def test_preserves_additional_kwargs_and_metadata(self) -> None:
# Defensive: nothing in the current chain sets these on a system
# message, but losing them silently when something does in the
# future would be a regression. ``name`` in particular is the only
# ``additional_kwargs`` field that ChatLiteLLM's
# ``_convert_message_to_dict`` propagates onto the wire.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
],
additional_kwargs={"name": "surfsense_system", "x": 1},
response_metadata={"tokens": 42},
)
sys.id = "sys-msg-1"
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert flattened.system_message.content == "a\n\nb"
assert flattened.system_message.additional_kwargs == {
"name": "surfsense_system",
"x": 1,
}
assert flattened.system_message.response_metadata == {"tokens": 42}
assert flattened.system_message.id == "sys-msg-1"
def test_idempotent_when_run_twice(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
request = _make_request(sys)
first = _flattened_request(request)
assert first is not None
# Second pass on the already-flattened request should be a no-op.
# We re-wrap in a request stub since the helper inspects
# ``request.system_message.content``.
second_request = _make_request(first.system_message)
assert _flattened_request(second_request) is None
# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------
class TestMiddlewareWrap:
@pytest.mark.asyncio
async def test_async_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = await mw.awrap_model_call(request, handler)
assert result == "ok"
assert isinstance(captured["request"].system_message, SystemMessage)
assert captured["request"].system_message.content == "alpha\n\nbeta"
@pytest.mark.asyncio
async def test_async_passes_through_when_already_string(self) -> None:
sys = SystemMessage(content="just a string")
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
await mw.awrap_model_call(request, handler)
# Same request object: no override happened.
assert captured["request"] is request
def test_sync_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = mw.wrap_model_call(request, handler)
assert result == "ok"
assert captured["request"].system_message.content == "alpha\n\nbeta"
def test_sync_passes_through_when_no_system_message(self) -> None:
request = _make_request(None)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
mw.wrap_model_call(request, handler)
assert captured["request"] is request
# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------
def test_regression_five_block_system_collapses_to_one_block() -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
{"type": "text", "text": "<TodoListMiddleware section>"},
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
{"type": "text", "text": "<SkillsMiddleware section>"},
{"type": "text", "text": "<SubAgentMiddleware section>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message.content, str)
# The exact join doesn't matter for the cache_control accounting —
# only that there is exactly ONE content block when LiteLLM's
# AnthropicCacheControlHook later targets ``role: system``.
assert "<surfsense base" in flattened.system_message.content
assert "<SubAgentMiddleware" in flattened.system_message.content
def test_regression_human_message_not_modified() -> None:
# Sanity: the middleware MUST NOT touch user messages — only the
# system message. Multi-block user content is the path that carries
# image attachments and would lose its image_url block on
# accidental flatten.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
user = HumanMessage(
content=[
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
)
request = _make_request(sys)
request.messages = [user]
flattened = _flattened_request(request)
assert flattened is not None
# System flattened to string …
assert isinstance(flattened.system_message.content, str)
# … user message is untouched (the helper does not even look at it).
assert flattened.messages == [user]
assert isinstance(user.content, list)
assert len(user.content) == 2

Some files were not shown because too many files have changed in this diff Show more