mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/ui-revamp
This commit is contained in:
commit
4e8c552440
142 changed files with 14603 additions and 6056 deletions
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
||||||
0.0.19
|
0.0.20
|
||||||
|
|
|
||||||
|
|
@ -308,6 +308,24 @@ STT_SERVICE=local/base
|
||||||
# Advanced (optional)
|
# Advanced (optional)
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# New-chat agent feature flags
|
||||||
|
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||||
|
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||||
|
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||||
|
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||||
|
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||||
|
SURFSENSE_ENABLE_BUSY_MUTEX=true
|
||||||
|
SURFSENSE_ENABLE_SKILLS=true
|
||||||
|
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
|
||||||
|
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
|
||||||
|
SURFSENSE_ENABLE_ACTION_LOG=true
|
||||||
|
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
||||||
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||||
|
|
||||||
# Periodic connector sync interval (default: 5m)
|
# Periodic connector sync interval (default: 5m)
|
||||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -324,3 +324,30 @@ LANGSMITH_PROJECT=surfsense
|
||||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||||
# Comma-separated allowlist of plugin entry-point names
|
# Comma-separated allowlist of plugin entry-point names
|
||||||
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
|
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
|
||||||
|
# on a cold turn) is reused across subsequent turns on the same thread,
|
||||||
|
# collapsing it to a microsecond hash lookup. All connector tools acquire
|
||||||
|
# their own short-lived DB session per call (Phase 2 refactor) so a cached
|
||||||
|
# closure is safe to share across requests. Flip OFF only as a last-resort
|
||||||
|
# rollback if you suspect cache-related staleness.
|
||||||
|
# SURFSENSE_ENABLE_AGENT_CACHE=true
|
||||||
|
|
||||||
|
# Cache capacity (max number of compiled-agent entries kept in memory)
|
||||||
|
# and TTL per entry (seconds). Working set is typically one entry per
|
||||||
|
# active thread on this replica; tune up for very large deployments.
|
||||||
|
# SURFSENSE_AGENT_CACHE_MAXSIZE=256
|
||||||
|
# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Connector discovery TTL cache (Phase 1.4 perf optimization)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Caches the per-search-space "available connectors" + "available document
|
||||||
|
# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
|
||||||
|
# ORM event listeners auto-invalidate on connector / document inserts,
|
||||||
|
# updates and deletes — the TTL only bounds staleness for bulk-import
|
||||||
|
# paths that bypass the ORM. Set to 0 to disable the cache.
|
||||||
|
# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30
|
||||||
|
|
|
||||||
|
|
@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
|
||||||
COPY pyproject.toml .
|
COPY pyproject.toml .
|
||||||
COPY uv.lock .
|
COPY uv.lock .
|
||||||
|
|
||||||
# Install PyTorch based on architecture
|
# Install all Python dependencies from uv.lock for deterministic builds.
|
||||||
RUN if [ "$(uname -m)" = "x86_64" ]; then \
|
#
|
||||||
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
|
# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
|
||||||
else \
|
# which lets prod silently drift to newer upstream versions on every rebuild
|
||||||
pip install --no-cache-dir torch torchvision torchaudio; \
|
# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
|
||||||
fi
|
# Exporting the lock to requirements.txt and feeding it to `uv pip install`
|
||||||
|
# pins every transitive package to the exact version captured in uv.lock.
|
||||||
# Install python dependencies
|
#
|
||||||
|
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
|
||||||
|
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
|
||||||
|
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
|
||||||
|
# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
|
||||||
|
# downloads that the lock-based install immediately replaced. If a specific
|
||||||
|
# CUDA version is needed (driver compatibility, etc.), wire it through
|
||||||
|
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
|
||||||
RUN pip install --no-cache-dir uv && \
|
RUN pip install --no-cache-dir uv && \
|
||||||
uv pip install --system --no-cache-dir -e .
|
uv export --frozen --no-dev --no-hashes --no-emit-project \
|
||||||
|
--format requirements-txt -o /tmp/requirements.txt && \
|
||||||
|
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
|
||||||
|
rm /tmp/requirements.txt
|
||||||
|
|
||||||
# Set SSL environment variables dynamically
|
# Set SSL environment variables dynamically
|
||||||
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
|
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
|
||||||
|
|
@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
|
||||||
# Pre-download Docling models
|
# Pre-download Docling models
|
||||||
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
|
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
|
||||||
|
|
||||||
# Install Playwright browsers for web scraping if needed
|
# Install Playwright browsers for web scraping (the playwright package itself
|
||||||
RUN pip install playwright && \
|
# is already installed via uv.lock above)
|
||||||
playwright install chromium --with-deps
|
RUN playwright install chromium --with-deps
|
||||||
|
|
||||||
# Copy source code
|
# Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
|
# Install the project itself in editable mode. Dependencies were already
|
||||||
|
# installed deterministically from uv.lock above, so --no-deps prevents any
|
||||||
|
# re-resolution that could pull newer versions.
|
||||||
|
RUN uv pip install --system --no-cache-dir --no-deps -e .
|
||||||
|
|
||||||
# Copy and set permissions for entrypoint script
|
# Copy and set permissions for entrypoint script
|
||||||
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
|
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
|
||||||
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
|
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
|
||||||
|
|
|
||||||
357
surfsense_backend/app/agents/new_chat/agent_cache.py
Normal file
357
surfsense_backend/app/agents/new_chat/agent_cache.py
Normal file
|
|
@ -0,0 +1,357 @@
|
||||||
|
"""TTL-LRU cache for compiled SurfSense deep agents.
|
||||||
|
|
||||||
|
Why this exists
|
||||||
|
---------------
|
||||||
|
|
||||||
|
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
|
||||||
|
turn:
|
||||||
|
|
||||||
|
1. Discover connectors & document types from Postgres (~50-200ms)
|
||||||
|
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
|
||||||
|
3. Compose the system prompt
|
||||||
|
4. Construct ~15 middleware instances (CPU)
|
||||||
|
5. Eagerly compile the general-purpose subagent
|
||||||
|
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
|
||||||
|
which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure
|
||||||
|
CPU work)
|
||||||
|
6. Compile the outer LangGraph
|
||||||
|
|
||||||
|
For a single thread, all six steps produce the SAME object on every turn
|
||||||
|
unless the user has changed their LLM config, toggled a feature flag,
|
||||||
|
added a connector, etc. The right answer is to compile ONCE per
|
||||||
|
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
|
||||||
|
every subsequent turn on the same thread.
|
||||||
|
|
||||||
|
Why a per-thread key (not a global pool)
|
||||||
|
----------------------------------------
|
||||||
|
|
||||||
|
Most middleware in the SurfSense stack captures per-thread state in
|
||||||
|
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
|
||||||
|
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
|
||||||
|
would silently leak state across users and threads. Keying the cache on
|
||||||
|
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
|
||||||
|
turns on the same thread without changing any middleware's behavior.
|
||||||
|
|
||||||
|
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
|
||||||
|
(read via ``runtime.context``) so the cache can collapse to a single
|
||||||
|
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
|
||||||
|
then, per-thread keying is the only safe option.
|
||||||
|
|
||||||
|
Cache shape
|
||||||
|
-----------
|
||||||
|
|
||||||
|
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
|
||||||
|
minutes — matches a typical chat session). ``maxsize`` (default 256)
|
||||||
|
caps memory; LRU evicts least-recently-used on overflow.
|
||||||
|
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
|
||||||
|
cold misses on the same key wait for the first build instead of
|
||||||
|
building N times.
|
||||||
|
* Process-local: this is an in-memory cache. Multi-replica deployments
|
||||||
|
pay the build cost once per replica per key. That's fine; the working
|
||||||
|
set per replica is small (one entry per active thread on that replica).
|
||||||
|
|
||||||
|
Telemetry
|
||||||
|
---------
|
||||||
|
|
||||||
|
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
|
||||||
|
|
||||||
|
* ``hit`` — cache hit, microseconds-fast
|
||||||
|
* ``miss`` — first build for this key, includes build duration
|
||||||
|
* ``stale`` — entry was found but expired; rebuilt
|
||||||
|
* ``evict`` — LRU eviction (size-limited)
|
||||||
|
* ``size`` — current cache occupancy at lookup time
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API: signature helpers (cache key components)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def stable_hash(*parts: Any) -> str:
|
||||||
|
"""Compute a deterministic SHA1 of the str repr of ``parts``.
|
||||||
|
|
||||||
|
Used for cache key components that need a fixed-width representation
|
||||||
|
(system prompt, tool list, etc.). SHA1 is fine here — this is not a
|
||||||
|
security boundary, just a content fingerprint.
|
||||||
|
"""
|
||||||
|
h = hashlib.sha1(usedforsecurity=False)
|
||||||
|
for p in parts:
|
||||||
|
h.update(repr(p).encode("utf-8", errors="replace"))
|
||||||
|
h.update(b"\x1f") # ASCII unit separator between parts
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def tools_signature(
|
||||||
|
tools: list[Any] | tuple[Any, ...],
|
||||||
|
*,
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
available_document_types: list[str] | None,
|
||||||
|
) -> str:
|
||||||
|
"""Hash the bound-tool surface for cache-key purposes.
|
||||||
|
|
||||||
|
The signature changes whenever:
|
||||||
|
|
||||||
|
* A tool is added or removed from the bound list (built-in toggles,
|
||||||
|
MCP tools loaded for the user changes, gating rules flip, etc.).
|
||||||
|
* The available connectors / document types for the search space
|
||||||
|
change (new connector added, last connector removed, new document
|
||||||
|
type indexed). Because :func:`get_connector_gated_tools` derives
|
||||||
|
``modified_disabled_tools`` from ``available_connectors``, the
|
||||||
|
tool surface is technically already covered — but we hash the
|
||||||
|
connector list separately so an empty-list "no tools changed"
|
||||||
|
situation still rotates the key when, say, the user re-adds a
|
||||||
|
connector that gates a tool we were already not exposing.
|
||||||
|
|
||||||
|
Stays stable across:
|
||||||
|
|
||||||
|
* Process restarts (tool names + descriptions are static).
|
||||||
|
* Different replicas (everyone gets the same hash for the same
|
||||||
|
inputs).
|
||||||
|
"""
|
||||||
|
tool_descriptors = sorted(
|
||||||
|
(getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
|
||||||
|
)
|
||||||
|
connectors = sorted(available_connectors or [])
|
||||||
|
doc_types = sorted(available_document_types or [])
|
||||||
|
return stable_hash(tool_descriptors, connectors, doc_types)
|
||||||
|
|
||||||
|
|
||||||
|
def flags_signature(flags: Any) -> str:
|
||||||
|
"""Hash the resolved :class:`AgentFeatureFlags` dataclass.
|
||||||
|
|
||||||
|
Frozen dataclasses are deterministically reprable, so a SHA1 of their
|
||||||
|
repr is a stable fingerprint. Restart safe (flags are read once at
|
||||||
|
process boot).
|
||||||
|
"""
|
||||||
|
return stable_hash(repr(flags))
|
||||||
|
|
||||||
|
|
||||||
|
def system_prompt_hash(system_prompt: str) -> str:
|
||||||
|
"""Hash a system prompt string. Cheap, ~30µs for typical prompts."""
|
||||||
|
return hashlib.sha1(
|
||||||
|
system_prompt.encode("utf-8", errors="replace"),
|
||||||
|
usedforsecurity=False,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Cache implementation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Entry:
|
||||||
|
value: Any
|
||||||
|
created_at: float
|
||||||
|
last_used_at: float
|
||||||
|
|
||||||
|
|
||||||
|
class _AgentCache:
|
||||||
|
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
|
||||||
|
|
||||||
|
NOT THREAD-SAFE in the multithreading sense — designed for a single
|
||||||
|
asyncio event loop. Uvicorn runs one event loop per worker process,
|
||||||
|
so this is fine; multi-worker deployments simply each maintain their
|
||||||
|
own cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
|
||||||
|
self._maxsize = maxsize
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
self._entries: OrderedDict[str, _Entry] = OrderedDict()
|
||||||
|
# One lock per key — guards "build" so concurrent cold misses on
|
||||||
|
# the same key wait for the first build instead of all racing.
|
||||||
|
self._locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
def _now(self) -> float:
|
||||||
|
return time.monotonic()
|
||||||
|
|
||||||
|
def _is_fresh(self, entry: _Entry) -> bool:
|
||||||
|
return (self._now() - entry.created_at) < self._ttl
|
||||||
|
|
||||||
|
def _evict_if_full(self) -> None:
|
||||||
|
while len(self._entries) >= self._maxsize:
|
||||||
|
evicted_key, _ = self._entries.popitem(last=False)
|
||||||
|
self._locks.pop(evicted_key, None)
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] evict key=%s reason=lru size=%d",
|
||||||
|
_short(evicted_key),
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _touch(self, key: str, entry: _Entry) -> None:
|
||||||
|
entry.last_used_at = self._now()
|
||||||
|
self._entries.move_to_end(key, last=True)
|
||||||
|
|
||||||
|
async def get_or_build(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
*,
|
||||||
|
builder: Callable[[], Awaitable[Any]],
|
||||||
|
) -> Any:
|
||||||
|
"""Return the cached value for ``key`` or call ``builder()`` to make it.
|
||||||
|
|
||||||
|
``builder`` MUST be idempotent — concurrent cold misses on the
|
||||||
|
same key collapse to a single ``builder()`` call (the others
|
||||||
|
wait on the in-flight lock and observe the populated entry on
|
||||||
|
wake).
|
||||||
|
"""
|
||||||
|
# Fast path: hot hit.
|
||||||
|
entry = self._entries.get(key)
|
||||||
|
if entry is not None and self._is_fresh(entry):
|
||||||
|
self._touch(key, entry)
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] hit key=%s age=%.1fs size=%d",
|
||||||
|
_short(key),
|
||||||
|
self._now() - entry.created_at,
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return entry.value
|
||||||
|
|
||||||
|
# Stale entry — drop it; rebuild below.
|
||||||
|
if entry is not None and not self._is_fresh(entry):
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
|
||||||
|
_short(key),
|
||||||
|
self._now() - entry.created_at,
|
||||||
|
self._ttl,
|
||||||
|
)
|
||||||
|
self._entries.pop(key, None)
|
||||||
|
|
||||||
|
# Slow path: serialize concurrent misses for the same key.
|
||||||
|
lock = self._locks.setdefault(key, asyncio.Lock())
|
||||||
|
async with lock:
|
||||||
|
# Double-check after acquiring the lock — another waiter may
|
||||||
|
# have populated the entry while we slept.
|
||||||
|
entry = self._entries.get(key)
|
||||||
|
if entry is not None and self._is_fresh(entry):
|
||||||
|
self._touch(key, entry)
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
|
||||||
|
_short(key),
|
||||||
|
self._now() - entry.created_at,
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return entry.value
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
value = await builder()
|
||||||
|
except BaseException:
|
||||||
|
# Don't cache failed builds; let the next caller retry.
|
||||||
|
_perf_log.warning(
|
||||||
|
"[agent_cache] build_failed key=%s elapsed=%.3fs",
|
||||||
|
_short(key),
|
||||||
|
time.perf_counter() - t0,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
|
||||||
|
# Insert + evict.
|
||||||
|
self._evict_if_full()
|
||||||
|
now = self._now()
|
||||||
|
self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
|
||||||
|
self._entries.move_to_end(key, last=True)
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] miss key=%s build=%.3fs size=%d",
|
||||||
|
_short(key),
|
||||||
|
elapsed,
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def invalidate(self, key: str) -> bool:
|
||||||
|
"""Drop a single entry; return True if anything was removed."""
|
||||||
|
removed = self._entries.pop(key, None) is not None
|
||||||
|
self._locks.pop(key, None)
|
||||||
|
if removed:
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] invalidate key=%s size=%d",
|
||||||
|
_short(key),
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return removed
|
||||||
|
|
||||||
|
def invalidate_prefix(self, prefix: str) -> int:
|
||||||
|
"""Drop every entry whose key starts with ``prefix``. Returns count."""
|
||||||
|
keys = [k for k in self._entries if k.startswith(prefix)]
|
||||||
|
for k in keys:
|
||||||
|
self._entries.pop(k, None)
|
||||||
|
self._locks.pop(k, None)
|
||||||
|
if keys:
|
||||||
|
_perf_log.info(
|
||||||
|
"[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
|
||||||
|
_short(prefix),
|
||||||
|
len(keys),
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return len(keys)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
n = len(self._entries)
|
||||||
|
self._entries.clear()
|
||||||
|
self._locks.clear()
|
||||||
|
if n:
|
||||||
|
_perf_log.info("[agent_cache] clear removed=%d", n)
|
||||||
|
|
||||||
|
def stats(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"size": len(self._entries),
|
||||||
|
"maxsize": self._maxsize,
|
||||||
|
"ttl_seconds": self._ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _short(key: str, n: int = 16) -> str:
|
||||||
|
"""Truncate keys for log lines so they don't blow up log volume."""
|
||||||
|
return key if len(key) <= n else f"{key[:n]}..."
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
|
||||||
|
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
|
||||||
|
|
||||||
|
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache() -> _AgentCache:
|
||||||
|
"""Return the process-wide compiled-agent cache singleton."""
|
||||||
|
return _cache
|
||||||
|
|
||||||
|
|
||||||
|
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
|
||||||
|
"""Replace the singleton with a fresh cache. Tests only."""
|
||||||
|
global _cache
|
||||||
|
_cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
|
||||||
|
return _cache
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"flags_signature",
|
||||||
|
"get_cache",
|
||||||
|
"reload_for_tests",
|
||||||
|
"stable_hash",
|
||||||
|
"system_prompt_hash",
|
||||||
|
"tools_signature",
|
||||||
|
]
|
||||||
|
|
@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.agent_cache import (
|
||||||
|
flags_signature,
|
||||||
|
get_cache,
|
||||||
|
stable_hash,
|
||||||
|
system_prompt_hash,
|
||||||
|
tools_signature,
|
||||||
|
)
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
|
|
@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
DoomLoopMiddleware,
|
DoomLoopMiddleware,
|
||||||
FileIntentMiddleware,
|
FileIntentMiddleware,
|
||||||
|
FlattenSystemMessageMiddleware,
|
||||||
KnowledgeBasePersistenceMiddleware,
|
KnowledgeBasePersistenceMiddleware,
|
||||||
KnowledgePriorityMiddleware,
|
KnowledgePriorityMiddleware,
|
||||||
KnowledgeTreeMiddleware,
|
KnowledgeTreeMiddleware,
|
||||||
|
|
@ -330,23 +338,39 @@ async def create_surfsense_deep_agent(
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Discover available connectors and document types for this search space
|
# Discover available connectors and document types for this search space.
|
||||||
|
#
|
||||||
|
# NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
|
||||||
|
# ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
|
||||||
|
# SQLAlchemy explicitly forbids concurrent operations on the same session
|
||||||
|
# ("This session is provisioning a new connection; concurrent operations
|
||||||
|
# are not permitted on the same session"). The Phase 1.4 in-process TTL
|
||||||
|
# cache in ``connector_service`` already collapses the warm path to a
|
||||||
|
# near-zero pair of dict lookups, so sequential awaits cost nothing in
|
||||||
|
# the common case while remaining correct on cold cache misses.
|
||||||
available_connectors: list[str] | None = None
|
available_connectors: list[str] | None = None
|
||||||
available_document_types: list[str] | None = None
|
available_document_types: list[str] | None = None
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
connector_types = await connector_service.get_available_connectors(
|
try:
|
||||||
|
connector_types_result = await connector_service.get_available_connectors(
|
||||||
search_space_id
|
search_space_id
|
||||||
)
|
)
|
||||||
if connector_types:
|
if connector_types_result:
|
||||||
available_connectors = _map_connectors_to_searchable_types(connector_types)
|
available_connectors = _map_connectors_to_searchable_types(
|
||||||
|
connector_types_result
|
||||||
available_document_types = await connector_service.get_available_document_types(
|
|
||||||
search_space_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logging.warning("Failed to discover available connectors: %s", e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
available_document_types = (
|
||||||
|
await connector_service.get_available_document_types(search_space_id)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to discover available document types: %s", e)
|
||||||
|
except Exception as e: # pragma: no cover - defensive outer guard
|
||||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||||
|
|
@ -469,8 +493,16 @@ async def create_surfsense_deep_agent(
|
||||||
# entire middleware build + main-graph compile into a single
|
# entire middleware build + main-graph compile into a single
|
||||||
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
||||||
# event loop stays responsive.
|
# event loop stays responsive.
|
||||||
_t0 = time.perf_counter()
|
#
|
||||||
agent = await asyncio.to_thread(
|
# PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
|
||||||
|
# on every per-request value that any middleware in the stack closes
|
||||||
|
# over in ``__init__`` — drop one and you risk leaking state across
|
||||||
|
# threads. Hits collapse this whole block to a microsecond lookup;
|
||||||
|
# misses pay the original CPU cost AND populate the cache.
|
||||||
|
config_id = agent_config.config_id if agent_config is not None else None
|
||||||
|
|
||||||
|
async def _build_agent() -> Any:
|
||||||
|
return await asyncio.to_thread(
|
||||||
_build_compiled_agent_blocking,
|
_build_compiled_agent_blocking,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
|
@ -484,14 +516,54 @@ async def create_surfsense_deep_agent(
|
||||||
anon_session_id=anon_session_id,
|
anon_session_id=anon_session_id,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
available_document_types=available_document_types,
|
available_document_types=available_document_types,
|
||||||
|
# ``mentioned_document_ids`` is consumed by
|
||||||
|
# ``KnowledgePriorityMiddleware`` per turn via
|
||||||
|
# ``runtime.context`` (Phase 1.5). We still pass the
|
||||||
|
# caller-provided list here for the legacy fallback path
|
||||||
|
# (cache disabled / context not propagated) — the middleware
|
||||||
|
# drains its own copy after the first read so a cached graph
|
||||||
|
# never replays stale mentions.
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
max_input_tokens=_max_input_tokens,
|
max_input_tokens=_max_input_tokens,
|
||||||
flags=_flags,
|
flags=_flags,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
|
||||||
|
# Cache key components — order matters only for human readability;
|
||||||
|
# the resulting hash is what's stored. Every component must
|
||||||
|
# rotate on a real shape change AND stay stable across identical
|
||||||
|
# invocations.
|
||||||
|
cache_key = stable_hash(
|
||||||
|
"v1", # schema version of the key — bump if components change
|
||||||
|
config_id,
|
||||||
|
thread_id,
|
||||||
|
user_id,
|
||||||
|
search_space_id,
|
||||||
|
visibility,
|
||||||
|
filesystem_selection.mode,
|
||||||
|
anon_session_id,
|
||||||
|
tools_signature(
|
||||||
|
tools,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
),
|
||||||
|
flags_signature(_flags),
|
||||||
|
system_prompt_hash(final_system_prompt),
|
||||||
|
_max_input_tokens,
|
||||||
|
# ``mentioned_document_ids`` deliberately omitted — middleware
|
||||||
|
# reads it from ``runtime.context`` (Phase 1.5).
|
||||||
|
)
|
||||||
|
agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
|
||||||
|
else:
|
||||||
|
agent = await _build_agent()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
"[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
|
"on"
|
||||||
|
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
|
||||||
|
else "off",
|
||||||
)
|
)
|
||||||
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking(
|
||||||
noop_mw,
|
noop_mw,
|
||||||
retry_mw,
|
retry_mw,
|
||||||
fallback_mw,
|
fallback_mw,
|
||||||
|
# Coalesce a multi-text-block system message into one block
|
||||||
|
# immediately before the model call. Sits innermost on the
|
||||||
|
# system-message-mutation chain so it observes every appender
|
||||||
|
# (todo / filesystem / skills / subagents …) and prevents
|
||||||
|
# OpenRouter→Anthropic from redistributing ``cache_control``
|
||||||
|
# across N blocks and tripping Anthropic's 4-breakpoint cap.
|
||||||
|
# See ``middleware/flatten_system.py`` for full rationale.
|
||||||
|
FlattenSystemMessageMiddleware(),
|
||||||
# Tool-call repair must run after model emits but before
|
# Tool-call repair must run after model emits but before
|
||||||
# permission / dedup / doom-loop interpret the calls.
|
# permission / dedup / doom-loop interpret the calls.
|
||||||
repair_mw,
|
repair_mw,
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,25 @@
|
||||||
"""
|
"""
|
||||||
Context schema definitions for SurfSense agents.
|
Context schema definitions for SurfSense agents.
|
||||||
|
|
||||||
This module defines the custom state schema used by the SurfSense deep agent.
|
This module defines the per-invocation context object passed to the SurfSense
|
||||||
|
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
|
||||||
|
|
||||||
|
The agent's compiled graph is the same across invocations (and cached by
|
||||||
|
``agent_cache``), so anything that varies per turn — the user mentions a
|
||||||
|
specific document, the front-end issues a unique ``request_id``, etc. —
|
||||||
|
MUST live on this context object instead of being captured into a
|
||||||
|
middleware ``__init__`` closure. Middlewares read fields back via
|
||||||
|
``runtime.context.<field>``; tools read them via ``runtime.context``.
|
||||||
|
|
||||||
|
This object is read inside both ``KnowledgePriorityMiddleware`` (for
|
||||||
|
``mentioned_document_ids``) and any future middleware that needs
|
||||||
|
per-request state without invalidating the compiled-agent cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import NotRequired, TypedDict
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
|
||||||
class FileOperationContractState(TypedDict):
|
class FileOperationContractState(TypedDict):
|
||||||
|
|
@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
|
||||||
turn_id: str
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseContextSchema(TypedDict):
|
@dataclass
|
||||||
|
class SurfSenseContextSchema:
|
||||||
"""
|
"""
|
||||||
Custom state schema for the SurfSense deep agent.
|
Per-invocation context for the SurfSense deep agent.
|
||||||
|
|
||||||
This extends the default agent state with custom fields.
|
Defaults are chosen so the dataclass can be safely default-constructed
|
||||||
The default state already includes:
|
(LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
|
||||||
- messages: Conversation history
|
context is supplied — see ``langgraph.runtime.Runtime``). All fields
|
||||||
- todos: Task list from TodoListMiddleware
|
are optional; consumers must None-check before reading.
|
||||||
- files: Virtual filesystem from FilesystemMiddleware
|
|
||||||
|
|
||||||
We're adding fields needed for knowledge base search:
|
Phase 1.5 fields:
|
||||||
- search_space_id: The user's search space ID
|
search_space_id: Search space the request is scoped to.
|
||||||
- db_session: Database session (injected at runtime)
|
mentioned_document_ids: KB documents the user @-mentioned this turn.
|
||||||
- connector_service: Connector service instance (injected at runtime)
|
Read by ``KnowledgePriorityMiddleware`` to seed its priority
|
||||||
|
list. Stays out of the compiled-agent cache key — that's the
|
||||||
|
whole point of putting it here.
|
||||||
|
file_operation_contract: One-shot file operation contract emitted
|
||||||
|
by ``FileIntentMiddleware`` for the upcoming turn.
|
||||||
|
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||||
|
task; populated for telemetry.
|
||||||
|
|
||||||
|
Phase 2 will extend with: thread_id, user_id, visibility,
|
||||||
|
filesystem_mode, anon_session_id, available_connectors,
|
||||||
|
available_document_types, created_by_id (everything currently captured
|
||||||
|
by middleware ``__init__`` closures).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_space_id: int
|
search_space_id: int | None = None
|
||||||
file_operation_contract: NotRequired[FileOperationContractState]
|
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||||
turn_id: NotRequired[str]
|
file_operation_contract: FileOperationContractState | None = None
|
||||||
request_id: NotRequired[str]
|
turn_id: str | None = None
|
||||||
# These are runtime-injected and won't be serialized
|
request_id: str | None = None
|
||||||
# db_session and connector_service are passed when invoking the agent
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
|
||||||
|
|
||||||
These flags gate the newer agent middleware (some ported from OpenCode,
|
These flags gate the newer agent middleware (some ported from OpenCode,
|
||||||
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
||||||
SurfSense-native). They follow a "default-OFF for risky things,
|
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
|
||||||
default-ON for safe upgrades, master kill-switch for everything new" model.
|
image updates work even when older installs do not have newly introduced
|
||||||
|
environment variables. Risky/experimental integrations stay default OFF,
|
||||||
|
and the master kill-switch can still disable everything new.
|
||||||
|
|
||||||
All new middleware checks its flag at agent build time. If the master
|
All new middleware checks its flag at agent build time. If the master
|
||||||
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
||||||
|
|
@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
|
|
||||||
Local development (recommended for trying everything except doom-loop / selector):
|
Defaults:
|
||||||
|
|
||||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||||
|
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||||
|
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||||
|
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||||
|
|
||||||
Master kill-switch (overrides everything else):
|
Master kill-switch (overrides everything else):
|
||||||
|
|
||||||
|
|
@ -60,32 +65,28 @@ class AgentFeatureFlags:
|
||||||
disable_new_agent_stack: bool = False
|
disable_new_agent_stack: bool = False
|
||||||
|
|
||||||
# Agent quality — context budget, retry/limits, name-repair, doom-loop
|
# Agent quality — context budget, retry/limits, name-repair, doom-loop
|
||||||
enable_context_editing: bool = False
|
enable_context_editing: bool = True
|
||||||
enable_compaction_v2: bool = False
|
enable_compaction_v2: bool = True
|
||||||
enable_retry_after: bool = False
|
enable_retry_after: bool = True
|
||||||
enable_model_fallback: bool = False
|
enable_model_fallback: bool = False
|
||||||
enable_model_call_limit: bool = False
|
enable_model_call_limit: bool = True
|
||||||
enable_tool_call_limit: bool = False
|
enable_tool_call_limit: bool = True
|
||||||
enable_tool_call_repair: bool = False
|
enable_tool_call_repair: bool = True
|
||||||
enable_doom_loop: bool = (
|
enable_doom_loop: bool = True
|
||||||
False # Default OFF until UI handles permission='doom_loop'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Safety — permissions, concurrency, tool-set narrowing
|
# Safety — permissions, concurrency, tool-set narrowing
|
||||||
enable_permission: bool = False # Default OFF for first deploy
|
enable_permission: bool = True
|
||||||
enable_busy_mutex: bool = False
|
enable_busy_mutex: bool = True
|
||||||
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
|
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
|
||||||
|
|
||||||
# Skills + subagents
|
# Skills + subagents
|
||||||
enable_skills: bool = False
|
enable_skills: bool = True
|
||||||
enable_specialized_subagents: bool = False
|
enable_specialized_subagents: bool = True
|
||||||
enable_kb_planner_runnable: bool = False
|
enable_kb_planner_runnable: bool = True
|
||||||
|
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log: bool = False
|
enable_action_log: bool = True
|
||||||
enable_revert_route: bool = (
|
enable_revert_route: bool = True
|
||||||
False # Backend ships before UI; route returns 503 until this flips
|
|
||||||
)
|
|
||||||
|
|
||||||
# Streaming parity v2 — opt in to LangChain's structured
|
# Streaming parity v2 — opt in to LangChain's structured
|
||||||
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||||
|
|
@ -94,7 +95,7 @@ class AgentFeatureFlags:
|
||||||
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||||
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||||
# ship unconditionally because they're forward-compatible.
|
# ship unconditionally because they're forward-compatible.
|
||||||
enable_stream_parity_v2: bool = False
|
enable_stream_parity_v2: bool = True
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader: bool = False
|
enable_plugin_loader: bool = False
|
||||||
|
|
@ -102,6 +103,41 @@ class AgentFeatureFlags:
|
||||||
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||||
enable_otel: bool = False
|
enable_otel: bool = False
|
||||||
|
|
||||||
|
# Performance — compiled-agent cache (Phase 1 + Phase 2).
|
||||||
|
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
|
||||||
|
# graph if the cache key matches (LLM config + thread + tool surface +
|
||||||
|
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
|
||||||
|
# wall clock from ~4-5s to <50µs on cache hits.
|
||||||
|
#
|
||||||
|
# SAFETY (Phase 2 unblocked this default-on):
|
||||||
|
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
|
||||||
|
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
|
||||||
|
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
|
||||||
|
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
|
||||||
|
# ``update_memory``, ``search_surfsense_docs``) now acquire fresh
|
||||||
|
# short-lived ``AsyncSession`` instances per call via
|
||||||
|
# :data:`async_session_maker`. The factory still accepts ``db_session``
|
||||||
|
# for registry compatibility but ``del``'s it immediately — see any
|
||||||
|
# of those files' factory docstrings for the rationale. The ``llm``
|
||||||
|
# closure is per-(provider, model, config_id) which is already in
|
||||||
|
# the cache key, so the LLM is safe to share across cached hits of
|
||||||
|
# the same key. The KB priority middleware reads
|
||||||
|
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
|
||||||
|
# not its constructor closure, so the same compiled agent serves
|
||||||
|
# turns with different mention lists correctly.
|
||||||
|
#
|
||||||
|
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
|
||||||
|
# environment if a regression surfaces. The path is exercised by
|
||||||
|
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
|
||||||
|
enable_agent_cache: bool = True
|
||||||
|
# Phase 1 (deferred — measure first): pre-build & share the
|
||||||
|
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
|
||||||
|
# misses. Only helps when the outer cache MISSES (cache hits already
|
||||||
|
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
|
||||||
|
# until we have data showing cold misses are frequent enough to
|
||||||
|
# justify the extra global state.
|
||||||
|
enable_agent_cache_share_gp_subagent: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_env(cls) -> AgentFeatureFlags:
|
def from_env(cls) -> AgentFeatureFlags:
|
||||||
"""Read flags from environment.
|
"""Read flags from environment.
|
||||||
|
|
@ -115,48 +151,76 @@ class AgentFeatureFlags:
|
||||||
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
||||||
"middleware is forced OFF for this build."
|
"middleware is forced OFF for this build."
|
||||||
)
|
)
|
||||||
return cls(disable_new_agent_stack=True)
|
return cls(
|
||||||
|
disable_new_agent_stack=True,
|
||||||
|
enable_context_editing=False,
|
||||||
|
enable_compaction_v2=False,
|
||||||
|
enable_retry_after=False,
|
||||||
|
enable_model_fallback=False,
|
||||||
|
enable_model_call_limit=False,
|
||||||
|
enable_tool_call_limit=False,
|
||||||
|
enable_tool_call_repair=False,
|
||||||
|
enable_doom_loop=False,
|
||||||
|
enable_permission=False,
|
||||||
|
enable_busy_mutex=False,
|
||||||
|
enable_llm_tool_selector=False,
|
||||||
|
enable_skills=False,
|
||||||
|
enable_specialized_subagents=False,
|
||||||
|
enable_kb_planner_runnable=False,
|
||||||
|
enable_action_log=False,
|
||||||
|
enable_revert_route=False,
|
||||||
|
enable_stream_parity_v2=False,
|
||||||
|
enable_plugin_loader=False,
|
||||||
|
enable_otel=False,
|
||||||
|
enable_agent_cache=False,
|
||||||
|
enable_agent_cache_share_gp_subagent=False,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
disable_new_agent_stack=False,
|
disable_new_agent_stack=False,
|
||||||
# Agent quality
|
# Agent quality
|
||||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
|
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
|
||||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
|
||||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
|
||||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||||
enable_model_call_limit=_env_bool(
|
enable_model_call_limit=_env_bool(
|
||||||
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
|
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
|
||||||
),
|
),
|
||||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
|
||||||
enable_tool_call_repair=_env_bool(
|
enable_tool_call_repair=_env_bool(
|
||||||
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
|
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
|
||||||
),
|
),
|
||||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
|
||||||
# Safety
|
# Safety
|
||||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
|
||||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
|
||||||
enable_llm_tool_selector=_env_bool(
|
enable_llm_tool_selector=_env_bool(
|
||||||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
||||||
),
|
),
|
||||||
# Skills + subagents
|
# Skills + subagents
|
||||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
|
||||||
enable_specialized_subagents=_env_bool(
|
enable_specialized_subagents=_env_bool(
|
||||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
|
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
|
||||||
),
|
),
|
||||||
enable_kb_planner_runnable=_env_bool(
|
enable_kb_planner_runnable=_env_bool(
|
||||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
|
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
|
||||||
),
|
),
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||||
# Streaming parity v2
|
# Streaming parity v2
|
||||||
enable_stream_parity_v2=_env_bool(
|
enable_stream_parity_v2=_env_bool(
|
||||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
|
||||||
),
|
),
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||||
# Observability
|
# Observability
|
||||||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||||
|
# Performance
|
||||||
|
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
|
||||||
|
enable_agent_cache_share_gp_subagent=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def any_new_middleware_enabled(self) -> bool:
|
def any_new_middleware_enabled(self) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
# Provider mapping for LiteLLM model string construction
|
# Provider mapping for LiteLLM model string construction.
|
||||||
PROVIDER_MAP = {
|
#
|
||||||
"OPENAI": "openai",
|
# Single source of truth lives in
|
||||||
"ANTHROPIC": "anthropic",
|
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||||
"GROQ": "groq",
|
# runs during ``app.config`` class-body init) can resolve provider
|
||||||
"COHERE": "cohere",
|
# prefixes without dragging the agent / tools tree into module load
|
||||||
"GOOGLE": "gemini",
|
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||||
"OLLAMA": "ollama_chat",
|
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||||
"MISTRAL": "mistral",
|
# tests) keep working unchanged.
|
||||||
"AZURE_OPENAI": "azure",
|
from app.services.provider_capabilities import ( # noqa: E402
|
||||||
"OPENROUTER": "openrouter",
|
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||||
"XAI": "xai",
|
)
|
||||||
"BEDROCK": "bedrock",
|
|
||||||
"VERTEX_AI": "vertex_ai",
|
|
||||||
"TOGETHER_AI": "together_ai",
|
|
||||||
"FIREWORKS_AI": "fireworks_ai",
|
|
||||||
"DEEPSEEK": "openai",
|
|
||||||
"ALIBABA_QWEN": "openai",
|
|
||||||
"MOONSHOT": "openai",
|
|
||||||
"ZHIPU": "openai",
|
|
||||||
"GITHUB_MODELS": "github",
|
|
||||||
"REPLICATE": "replicate",
|
|
||||||
"PERPLEXITY": "perplexity",
|
|
||||||
"ANYSCALE": "anyscale",
|
|
||||||
"DEEPINFRA": "deepinfra",
|
|
||||||
"CEREBRAS": "cerebras",
|
|
||||||
"SAMBANOVA": "sambanova",
|
|
||||||
"AI21": "ai21",
|
|
||||||
"CLOUDFLARE": "cloudflare",
|
|
||||||
"DATABRICKS": "databricks",
|
|
||||||
"COMETAPI": "cometapi",
|
|
||||||
"HUGGINGFACE": "huggingface",
|
|
||||||
"MINIMAX": "openai",
|
|
||||||
"CUSTOM": "custom",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||||
|
|
@ -178,6 +155,17 @@ class AgentConfig:
|
||||||
anonymous_enabled: bool = False
|
anonymous_enabled: bool = False
|
||||||
quota_reserve_tokens: int | None = None
|
quota_reserve_tokens: int | None = None
|
||||||
|
|
||||||
|
# Capability flag: best-effort True for the chat selector / catalog.
|
||||||
|
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||||
|
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||||
|
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||||
|
# is the conservative-allow stance — the streaming-task safety net
|
||||||
|
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||||
|
# actually blocks a request. Setting this to False here without an
|
||||||
|
# authoritative source would silently hide vision-capable models
|
||||||
|
# (the regression we're fixing).
|
||||||
|
supports_image_input: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_auto_mode(cls) -> "AgentConfig":
|
def from_auto_mode(cls) -> "AgentConfig":
|
||||||
"""
|
"""
|
||||||
|
|
@ -203,6 +191,12 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
|
# Auto routes across the configured pool, which usually
|
||||||
|
# contains at least one vision-capable deployment; the router
|
||||||
|
# will surface a 404 from a non-vision deployment as a normal
|
||||||
|
# ``allowed_fails`` event and fail over rather than blocking
|
||||||
|
# the request outright.
|
||||||
|
supports_image_input=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -216,10 +210,24 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
return cls(
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
provider=config.provider.value
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
|
provider_value = (
|
||||||
|
config.provider.value
|
||||||
if hasattr(config.provider, "value")
|
if hasattr(config.provider, "value")
|
||||||
else str(config.provider),
|
else str(config.provider)
|
||||||
|
)
|
||||||
|
litellm_params = config.litellm_params or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
provider=provider_value,
|
||||||
model_name=config.model_name,
|
model_name=config.model_name,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
api_base=config.api_base,
|
api_base=config.api_base,
|
||||||
|
|
@ -235,6 +243,16 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
|
# BYOK rows have no operator-curated capability flag, so we
|
||||||
|
# ask LiteLLM (default-allow on unknown). The streaming
|
||||||
|
# safety net still blocks if the model is *explicitly*
|
||||||
|
# marked text-only.
|
||||||
|
supports_image_input=derive_supports_image_input(
|
||||||
|
provider=provider_value,
|
||||||
|
model_name=config.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=config.custom_provider,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -253,15 +271,46 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
# Get system instructions from YAML, default to empty string
|
# Get system instructions from YAML, default to empty string
|
||||||
system_instructions = yaml_config.get("system_instructions", "")
|
system_instructions = yaml_config.get("system_instructions", "")
|
||||||
|
|
||||||
|
provider = yaml_config.get("provider", "").upper()
|
||||||
|
model_name = yaml_config.get("model_name", "")
|
||||||
|
custom_provider = yaml_config.get("custom_provider")
|
||||||
|
litellm_params = yaml_config.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||||
|
# OpenRouter modalities. The YAML loader already populates this
|
||||||
|
# field, but this method is also called from
|
||||||
|
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||||
|
# so we re-derive here for safety. The bool() coercion preserves
|
||||||
|
# the loader's behaviour for explicit ``true`` / ``false``
|
||||||
|
# strings that PyYAML may surface.
|
||||||
|
if "supports_image_input" in yaml_config:
|
||||||
|
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||||
|
else:
|
||||||
|
supports_image_input = derive_supports_image_input(
|
||||||
|
provider=provider,
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=custom_provider,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
provider=yaml_config.get("provider", "").upper(),
|
provider=provider,
|
||||||
model_name=yaml_config.get("model_name", ""),
|
model_name=model_name,
|
||||||
api_key=yaml_config.get("api_key", ""),
|
api_key=yaml_config.get("api_key", ""),
|
||||||
api_base=yaml_config.get("api_base"),
|
api_base=yaml_config.get("api_base"),
|
||||||
custom_provider=yaml_config.get("custom_provider"),
|
custom_provider=custom_provider,
|
||||||
litellm_params=yaml_config.get("litellm_params"),
|
litellm_params=yaml_config.get("litellm_params"),
|
||||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||||
system_instructions=system_instructions if system_instructions else None,
|
system_instructions=system_instructions if system_instructions else None,
|
||||||
|
|
@ -276,6 +325,7 @@ class AgentConfig:
|
||||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||||
|
supports_image_input=supports_image_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
|
||||||
from app.agents.new_chat.middleware.filesystem import (
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.flatten_system import (
|
||||||
|
FlattenSystemMessageMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.kb_persistence import (
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
KnowledgeBasePersistenceMiddleware,
|
KnowledgeBasePersistenceMiddleware,
|
||||||
commit_staged_filesystem_state,
|
commit_staged_filesystem_state,
|
||||||
|
|
@ -61,6 +64,7 @@ __all__ = [
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
"DoomLoopMiddleware",
|
"DoomLoopMiddleware",
|
||||||
"FileIntentMiddleware",
|
"FileIntentMiddleware",
|
||||||
|
"FlattenSystemMessageMiddleware",
|
||||||
"KnowledgeBasePersistenceMiddleware",
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
"KnowledgeBaseSearchMiddleware",
|
||||||
"KnowledgePriorityMiddleware",
|
"KnowledgePriorityMiddleware",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,233 @@
|
||||||
|
r"""Coalesce multi-block system messages into a single text block.
|
||||||
|
|
||||||
|
Several middlewares in our deepagent stack each call
|
||||||
|
``append_to_system_message`` on the way down to the model
|
||||||
|
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||||
|
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||||
|
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||||
|
|
||||||
|
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||||
|
request**, and we configure 2 injection points
|
||||||
|
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||||
|
the prepended ``request.system_message``, this middleware is the
|
||||||
|
defensive partner: it guarantees that "the system block" is *one*
|
||||||
|
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||||
|
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||||
|
several breakpoints by spreading ``cache_control`` across multiple
|
||||||
|
text blocks of a multi-block system content.
|
||||||
|
|
||||||
|
Without flattening we used to see::
|
||||||
|
|
||||||
|
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||||
|
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||||
|
cache_control may be provided. Found 5."}}}
|
||||||
|
|
||||||
|
(Same error class documented in
|
||||||
|
https://github.com/BerriAI/litellm/issues/15696 and
|
||||||
|
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||||
|
in PR #15395 covers the litellm transformer but does not protect us
|
||||||
|
when the OpenRouter SaaS itself does the redistribution.)
|
||||||
|
|
||||||
|
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
|
||||||
|
the first injection point from ``role: system`` to ``index: 0``)
|
||||||
|
neutralises the *primary* cause of the same 400 — multiple
|
||||||
|
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||||
|
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||||
|
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||||
|
matcher. This middleware remains useful as defence-in-depth against
|
||||||
|
the multi-block redistribution path.
|
||||||
|
|
||||||
|
Placement: innermost on the system-message-mutation chain, after every
|
||||||
|
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||||
|
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||||
|
attempt sees a flattened payload. See ``chat_deepagent.py``.
|
||||||
|
|
||||||
|
Idempotent: a string-content system message is left untouched. A list
|
||||||
|
that contains anything other than plain text blocks (e.g. an image) is
|
||||||
|
also left untouched — those are rare on system messages and we'd lose
|
||||||
|
the non-text payload by joining.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_text_blocks(content: list[Any]) -> str | None:
|
||||||
|
"""Return joined text if every block is a plain ``{"type": "text"}``.
|
||||||
|
|
||||||
|
Returns ``None`` when the list contains anything that isn't a text
|
||||||
|
block we can safely concatenate (image, audio, file, non-standard
|
||||||
|
blocks, dicts with extra non-cache_control fields). The caller
|
||||||
|
leaves the original content untouched in that case rather than
|
||||||
|
silently dropping payload.
|
||||||
|
|
||||||
|
``cache_control`` on individual blocks is intentionally discarded —
|
||||||
|
the whole point of flattening is to let LiteLLM's
|
||||||
|
``cache_control_injection_points`` re-place a single breakpoint on
|
||||||
|
the resulting one-block system content.
|
||||||
|
"""
|
||||||
|
chunks: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
chunks.append(block)
|
||||||
|
continue
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
return None
|
||||||
|
if block.get("type") != "text":
|
||||||
|
return None
|
||||||
|
text = block.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return None
|
||||||
|
chunks.append(text)
|
||||||
|
return "\n\n".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def _flattened_request(
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
) -> ModelRequest[ContextT] | None:
|
||||||
|
"""Return a request with system_message flattened, or ``None`` for no-op."""
|
||||||
|
sys_msg = request.system_message
|
||||||
|
if sys_msg is None:
|
||||||
|
return None
|
||||||
|
content = sys_msg.content
|
||||||
|
if not isinstance(content, list) or len(content) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
flattened = _flatten_text_blocks(content)
|
||||||
|
if flattened is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_sys = SystemMessage(
|
||||||
|
content=flattened,
|
||||||
|
additional_kwargs=dict(sys_msg.additional_kwargs),
|
||||||
|
response_metadata=dict(sys_msg.response_metadata),
|
||||||
|
)
|
||||||
|
if sys_msg.id is not None:
|
||||||
|
new_sys.id = sys_msg.id
|
||||||
|
return request.override(system_message=new_sys)
|
||||||
|
|
||||||
|
|
||||||
|
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
|
||||||
|
"""One-line dump of cache_control-relevant request shape.
|
||||||
|
|
||||||
|
Temporary diagnostic to prove where the ``Found N`` cache_control
|
||||||
|
breakpoints are coming from when Anthropic 400s. Removed once the
|
||||||
|
root cause is confirmed and a fix is in place.
|
||||||
|
"""
|
||||||
|
sys_msg = request.system_message
|
||||||
|
if sys_msg is None:
|
||||||
|
sys_shape = "none"
|
||||||
|
elif isinstance(sys_msg.content, str):
|
||||||
|
sys_shape = f"str(len={len(sys_msg.content)})"
|
||||||
|
elif isinstance(sys_msg.content, list):
|
||||||
|
sys_shape = f"list(blocks={len(sys_msg.content)})"
|
||||||
|
else:
|
||||||
|
sys_shape = f"other({type(sys_msg.content).__name__})"
|
||||||
|
|
||||||
|
role_hist: list[str] = []
|
||||||
|
multi_block_msgs = 0
|
||||||
|
msgs_with_cc = 0
|
||||||
|
sys_msgs_in_history = 0
|
||||||
|
for m in request.messages:
|
||||||
|
mtype = getattr(m, "type", type(m).__name__)
|
||||||
|
role_hist.append(mtype)
|
||||||
|
if isinstance(m, SystemMessage):
|
||||||
|
sys_msgs_in_history += 1
|
||||||
|
c = getattr(m, "content", None)
|
||||||
|
if isinstance(c, list):
|
||||||
|
multi_block_msgs += 1
|
||||||
|
for blk in c:
|
||||||
|
if isinstance(blk, dict) and "cache_control" in blk:
|
||||||
|
msgs_with_cc += 1
|
||||||
|
break
|
||||||
|
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
|
||||||
|
msgs_with_cc += 1
|
||||||
|
|
||||||
|
tools = request.tools or []
|
||||||
|
tools_with_cc = 0
|
||||||
|
for t in tools:
|
||||||
|
if isinstance(t, dict) and (
|
||||||
|
"cache_control" in t or "cache_control" in t.get("function", {})
|
||||||
|
):
|
||||||
|
tools_with_cc += 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"sys={sys_shape} msgs={len(request.messages)} "
|
||||||
|
f"sys_msgs_in_history={sys_msgs_in_history} "
|
||||||
|
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
|
||||||
|
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
|
||||||
|
f"roles={role_hist[-8:]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenSystemMessageMiddleware(
|
||||||
|
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||||
|
):
|
||||||
|
"""Collapse a multi-text-block system message to a single string.
|
||||||
|
|
||||||
|
Sits innermost on the system-message-mutation chain so it observes
|
||||||
|
every middleware's contribution. Has no other side effect — the
|
||||||
|
body of every block is preserved, just joined with ``"\\n\\n"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.tools = []
|
||||||
|
|
||||||
|
def wrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||||
|
) -> Any:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
if flattened is not None:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(
|
||||||
|
"[flatten_system] collapsed %d system blocks to one",
|
||||||
|
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||||
|
)
|
||||||
|
return handler(flattened)
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
|
async def awrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
handler: Callable[
|
||||||
|
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||||
|
],
|
||||||
|
) -> Any:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
if flattened is not None:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(
|
||||||
|
"[flatten_system] collapsed %d system blocks to one",
|
||||||
|
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||||
|
)
|
||||||
|
return await handler(flattened)
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FlattenSystemMessageMiddleware",
|
||||||
|
"_flatten_text_blocks",
|
||||||
|
"_flattened_request",
|
||||||
|
]
|
||||||
|
|
@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
del runtime
|
|
||||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
if anon_doc:
|
if anon_doc:
|
||||||
return self._anon_priority(state, anon_doc)
|
return self._anon_priority(state, anon_doc)
|
||||||
|
|
||||||
return await self._authenticated_priority(state, messages, user_text)
|
return await self._authenticated_priority(state, messages, user_text, runtime)
|
||||||
|
|
||||||
def _anon_priority(
|
def _anon_priority(
|
||||||
self,
|
self,
|
||||||
|
|
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
user_text: str,
|
user_text: str,
|
||||||
|
runtime: Runtime[Any] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
t0 = asyncio.get_event_loop().time()
|
t0 = asyncio.get_event_loop().time()
|
||||||
(
|
(
|
||||||
|
|
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
mentioned_results: list[dict[str, Any]] = []
|
# Per-turn ``mentioned_document_ids`` flow:
|
||||||
|
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
|
||||||
|
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
|
||||||
|
# on every ``astream_events`` call, so this list is naturally
|
||||||
|
# scoped to the current turn. Allows cross-turn graph reuse via
|
||||||
|
# ``agent_cache``.
|
||||||
|
# 2. Legacy fallback (cache disabled / context not propagated): the
|
||||||
|
# constructor-injected ``self.mentioned_document_ids`` list. We
|
||||||
|
# drain it after the first read so a cached graph (no Phase 1.5
|
||||||
|
# wiring) doesn't keep replaying the same mentions on every
|
||||||
|
# turn.
|
||||||
|
#
|
||||||
|
# CRITICAL: distinguish "context absent" (legacy caller, no field at
|
||||||
|
# all) from "context provided but empty" (turn with no mentions).
|
||||||
|
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
|
||||||
|
# Python, so a naive ``if ctx_mentions:`` would fall through to the
|
||||||
|
# legacy closure on every no-mention follow-up turn — replaying the
|
||||||
|
# mentions baked in by turn 1's cache-miss build. Always drain the
|
||||||
|
# closure once the runtime path has fired so a cached middleware
|
||||||
|
# instance can never resurrect stale state.
|
||||||
|
mention_ids: list[int] = []
|
||||||
|
ctx = getattr(runtime, "context", None) if runtime is not None else None
|
||||||
|
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
|
||||||
|
if ctx_mentions is not None:
|
||||||
|
# Runtime path is authoritative — even an empty list means
|
||||||
|
# "this turn has no mentions", NOT "look at the closure".
|
||||||
|
mention_ids = list(ctx_mentions)
|
||||||
if self.mentioned_document_ids:
|
if self.mentioned_document_ids:
|
||||||
|
self.mentioned_document_ids = []
|
||||||
|
elif self.mentioned_document_ids:
|
||||||
|
mention_ids = list(self.mentioned_document_ids)
|
||||||
|
self.mentioned_document_ids = []
|
||||||
|
|
||||||
|
mentioned_results: list[dict[str, Any]] = []
|
||||||
|
if mention_ids:
|
||||||
mentioned_results = await fetch_mentioned_documents(
|
mentioned_results = await fetch_mentioned_documents(
|
||||||
document_ids=self.mentioned_document_ids,
|
document_ids=mention_ids,
|
||||||
search_space_id=self.search_space_id,
|
search_space_id=self.search_space_id,
|
||||||
)
|
)
|
||||||
self.mentioned_document_ids = []
|
|
||||||
|
|
||||||
if is_recency:
|
if is_recency:
|
||||||
doc_types = _resolve_search_types(
|
doc_types = _resolve_search_types(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||||
|
|
||||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||||
|
|
@ -17,8 +17,20 @@ Coverage:
|
||||||
|
|
||||||
We inject **two** breakpoints per request:
|
We inject **two** breakpoints per request:
|
||||||
|
|
||||||
- ``role: system`` — pins the SurfSense system prompt (provider variant,
|
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||||
citation rules, tool catalog, KB tree, skills metadata) into the cache.
|
request (provider variant, citation rules, tool catalog, KB tree,
|
||||||
|
skills metadata). The langchain agent factory always prepends
|
||||||
|
``request.system_message`` at index 0 (see ``factory.py``
|
||||||
|
``_execute_model_async``), so this targets exactly the main system
|
||||||
|
prompt regardless of how many other ``SystemMessage``\ s the
|
||||||
|
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||||
|
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||||
|
``role: system`` here would apply ``cache_control`` to **every**
|
||||||
|
system-role message and trip Anthropic's hard cap of 4 cache
|
||||||
|
breakpoints per request once the conversation accumulates enough
|
||||||
|
injected system messages — which surfaces as the upstream 400
|
||||||
|
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||||
|
via OpenRouter→Anthropic.
|
||||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||||
N+1 still reads turn N's cache up to the shared prefix.
|
N+1 still reads turn N's cache up to the shared prefix.
|
||||||
|
|
@ -51,11 +63,21 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Two-breakpoint policy: system + latest message. See module docstring for
|
# Two-breakpoint policy: head-of-request + latest message. See module
|
||||||
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
|
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
||||||
# use 2 here, leaving headroom for Phase-2 tool caching.
|
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||||
|
#
|
||||||
|
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||||
|
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
|
||||||
|
# anonymous-doc) insert ``SystemMessage`` instances into
|
||||||
|
# ``state["messages"]`` that accumulate across turns. With
|
||||||
|
# ``role: system`` the LiteLLM hook would tag *every* one of them with
|
||||||
|
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
|
||||||
|
# always targets the langchain-prepended ``request.system_message``
|
||||||
|
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
|
||||||
|
# block), giving us exactly one stable cache breakpoint.
|
||||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
{"location": "message", "role": "system"},
|
{"location": "message", "index": 0},
|
||||||
{"location": "message", "index": -1},
|
{"location": "message", "index": -1},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.confluence import ConfluenceToolMetadataService
|
from app.services.confluence import ConfluenceToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_confluence_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_confluence_page tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_confluence_page(
|
async def create_confluence_page(
|
||||||
title: str,
|
title: str,
|
||||||
|
|
@ -42,13 +60,14 @@ def create_create_confluence_page_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_confluence_page called: title='{title}'")
|
logger.info(f"create_confluence_page called: title='{title}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
|
|
@ -183,7 +202,9 @@ def create_create_confluence_page_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.confluence import ConfluenceToolMetadataService
|
from app.services.confluence import ConfluenceToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the delete_confluence_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_confluence_page tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_confluence_page(
|
async def delete_confluence_page(
|
||||||
page_title_or_id: str,
|
page_title_or_id: str,
|
||||||
|
|
@ -43,13 +61,14 @@ def create_delete_confluence_page_tool(
|
||||||
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_deletion_context(
|
context = await metadata_service.get_deletion_context(
|
||||||
search_space_id, user_id, page_title_or_id
|
search_space_id, user_id, page_title_or_id
|
||||||
|
|
@ -95,7 +114,9 @@ def create_delete_confluence_page_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
@ -135,7 +156,10 @@ def create_delete_confluence_page_tool(
|
||||||
or "status code 403" in str(api_err).lower()
|
or "status code 403" in str(api_err).lower()
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
flag_modified(connector, "config")
|
flag_modified(connector, "config")
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.confluence import ConfluenceToolMetadataService
|
from app.services.confluence import ConfluenceToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the update_confluence_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_confluence_page tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_confluence_page(
|
async def update_confluence_page(
|
||||||
page_title_or_id: str,
|
page_title_or_id: str,
|
||||||
|
|
@ -45,13 +63,14 @@ def create_update_confluence_page_tool(
|
||||||
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, page_title_or_id
|
search_space_id, user_id, page_title_or_id
|
||||||
|
|
@ -152,7 +171,10 @@ def create_update_confluence_page_tool(
|
||||||
or "status code 403" in str(api_err).lower()
|
or "status code 403" in str(api_err).lower()
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
flag_modified(connector, "config")
|
flag_modified(connector, "config")
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
|
"""Factory function to create the get_connected_accounts tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to scope account discovery to.
|
||||||
|
user_id: User ID to scope account discovery to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured StructuredTool for connected-accounts discovery.
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
async def _run(service: str) -> list[dict[str, Any]]:
|
async def _run(service: str) -> list[dict[str, Any]]:
|
||||||
svc_cfg = MCP_SERVICES.get(service)
|
svc_cfg = MCP_SERVICES.get(service)
|
||||||
|
|
@ -68,6 +85,7 @@ def create_get_connected_accounts_tool(
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||||
|
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the list_discord_channels tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured list_discord_channels tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_discord_channels() -> dict[str, Any]:
|
async def list_discord_channels() -> dict[str, Any]:
|
||||||
"""List text channels in the connected Discord server.
|
"""List text channels in the connected Discord server.
|
||||||
|
|
@ -22,13 +41,14 @@ def create_list_discord_channels_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with status and a list of channels (id, name).
|
Dictionary with status and a list of channels (id, name).
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
connector = await get_discord_connector(
|
connector = await get_discord_connector(
|
||||||
db_session, search_space_id, user_id
|
db_session, search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the read_discord_messages tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured read_discord_messages tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def read_discord_messages(
|
async def read_discord_messages(
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
|
|
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
|
||||||
Dictionary with status and a list of messages including
|
Dictionary with status and a list of messages including
|
||||||
id, author, content, timestamp.
|
id, author, content, timestamp.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
|
|
@ -39,6 +58,7 @@ def create_read_discord_messages_tool(
|
||||||
limit = min(limit, 50)
|
limit = min(limit, 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
connector = await get_discord_connector(
|
connector = await get_discord_connector(
|
||||||
db_session, search_space_id, user_id
|
db_session, search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
|
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the send_discord_message tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured send_discord_message tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def send_discord_message(
|
async def send_discord_message(
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
|
|
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
|
||||||
IMPORTANT:
|
IMPORTANT:
|
||||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
|
|
@ -47,6 +65,7 @@ def create_send_discord_message_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
connector = await get_discord_connector(
|
connector = await get_discord_connector(
|
||||||
db_session, search_space_id, user_id
|
db_session, search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.dropbox.client import DropboxClient
|
from app.connectors.dropbox.client import DropboxClient
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_dropbox_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_dropbox_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_dropbox_file(
|
async def create_dropbox_file(
|
||||||
name: str,
|
name: str,
|
||||||
|
|
@ -82,13 +99,14 @@ def create_create_dropbox_file_tool(
|
||||||
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
|
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Dropbox tool not properly configured.",
|
"message": "Dropbox tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
@ -149,7 +167,9 @@ def create_create_dropbox_file_tool(
|
||||||
]
|
]
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Error fetching folders for connector %s", cid, exc_info=True
|
"Error fetching folders for connector %s",
|
||||||
|
cid,
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
parent_folders[cid] = []
|
parent_folders[cid] = []
|
||||||
|
|
||||||
|
|
@ -217,7 +237,9 @@ def create_create_dropbox_file_tool(
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_file_type == "paper":
|
if final_file_type == "paper":
|
||||||
created = await client.create_paper_doc(file_path, final_content or "")
|
created = await client.create_paper_doc(
|
||||||
|
file_path, final_content or ""
|
||||||
|
)
|
||||||
file_id = created.get("file_id", "")
|
file_id = created.get("file_id", "")
|
||||||
web_url = created.get("url", "")
|
web_url = created.get("url", "")
|
||||||
else:
|
else:
|
||||||
|
|
@ -246,7 +268,9 @@ def create_create_dropbox_file_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from app.db import (
|
||||||
DocumentType,
|
DocumentType,
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
|
async_session_maker,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the delete_dropbox_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_dropbox_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_dropbox_file(
|
async def delete_dropbox_file(
|
||||||
file_name: str,
|
file_name: str,
|
||||||
|
|
@ -55,13 +73,14 @@ def create_delete_dropbox_file_tool(
|
||||||
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Dropbox tool not properly configured.",
|
"message": "Dropbox tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
doc_result = await db_session.execute(
|
doc_result = await db_session.execute(
|
||||||
select(Document)
|
select(Document)
|
||||||
.join(
|
.join(
|
||||||
|
|
@ -193,14 +212,17 @@ def create_delete_dropbox_file_tool(
|
||||||
|
|
||||||
final_file_path = result.params.get("file_path", file_path)
|
final_file_path = result.params.get("file_path", file_path)
|
||||||
final_connector_id = result.params.get("connector_id", connector.id)
|
final_connector_id = result.params.get("connector_id", connector.id)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
if final_connector_id != connector.id:
|
if final_connector_id != connector.id:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
and_(
|
and_(
|
||||||
SearchSourceConnector.id == final_connector_id,
|
SearchSourceConnector.id == final_connector_id,
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id
|
||||||
|
== search_space_id,
|
||||||
SearchSourceConnector.user_id == user_id,
|
SearchSourceConnector.user_id == user_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.connector_type
|
||||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||||
|
|
@ -221,7 +243,9 @@ def create_delete_dropbox_file_tool(
|
||||||
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
|
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
client = DropboxClient(session=db_session, connector_id=actual_connector_id)
|
client = DropboxClient(
|
||||||
|
session=db_session, connector_id=actual_connector_id
|
||||||
|
)
|
||||||
await client.delete_file(final_file_path)
|
await client.delete_file(final_file_path)
|
||||||
|
|
||||||
logger.info(f"Dropbox file deleted: path={final_file_path}")
|
logger.info(f"Dropbox file deleted: path={final_file_path}")
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
|
||||||
ImageGenRouterService,
|
ImageGenRouterService,
|
||||||
is_image_gen_auto_mode,
|
is_image_gen_auto_mode,
|
||||||
)
|
)
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
from app.utils.signed_image_urls import generate_image_token
|
from app.utils.signed_image_urls import generate_image_token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||||
|
if custom_provider:
|
||||||
|
return custom_provider
|
||||||
|
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||||
|
|
||||||
|
|
||||||
def _build_model_string(
|
def _build_model_string(
|
||||||
provider: str, model_name: str, custom_provider: str | None
|
provider: str, model_name: str, custom_provider: str | None
|
||||||
) -> str:
|
) -> str:
|
||||||
if custom_provider:
|
prefix = _resolve_provider_prefix(provider, custom_provider)
|
||||||
return f"{custom_provider}/{model_name}"
|
|
||||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
|
||||||
return f"{prefix}/{model_name}"
|
return f"{prefix}/{model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -146,14 +151,18 @@ def create_generate_image_tool(
|
||||||
"error": f"Image generation config {config_id} not found"
|
"error": f"Image generation config {config_id} not found"
|
||||||
}
|
}
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
cfg.get("provider", ""),
|
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||||
cfg["model_name"],
|
|
||||||
cfg.get("custom_provider"),
|
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||||
if cfg.get("api_base"):
|
api_base = resolve_api_base(
|
||||||
gen_kwargs["api_base"] = cfg["api_base"]
|
provider=cfg.get("provider"),
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=cfg.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if cfg.get("api_version"):
|
if cfg.get("api_version"):
|
||||||
gen_kwargs["api_version"] = cfg["api_version"]
|
gen_kwargs["api_version"] = cfg["api_version"]
|
||||||
if cfg.get("litellm_params"):
|
if cfg.get("litellm_params"):
|
||||||
|
|
@ -175,14 +184,18 @@ def create_generate_image_tool(
|
||||||
"error": f"Image generation config {config_id} not found"
|
"error": f"Image generation config {config_id} not found"
|
||||||
}
|
}
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
db_cfg.provider.value,
|
db_cfg.provider.value, db_cfg.custom_provider
|
||||||
db_cfg.model_name,
|
|
||||||
db_cfg.custom_provider,
|
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||||
gen_kwargs["api_key"] = db_cfg.api_key
|
gen_kwargs["api_key"] = db_cfg.api_key
|
||||||
if db_cfg.api_base:
|
api_base = resolve_api_base(
|
||||||
gen_kwargs["api_base"] = db_cfg.api_base
|
provider=db_cfg.provider.value,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=db_cfg.api_base,
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if db_cfg.api_version:
|
if db_cfg.api_version:
|
||||||
gen_kwargs["api_version"] = db_cfg.api_version
|
gen_kwargs["api_version"] = db_cfg.api_version
|
||||||
if db_cfg.litellm_params:
|
if db_cfg.litellm_params:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
|
||||||
|
def split_recipients(value: str | None) -> list[str]:
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_composio_data(data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner)
|
||||||
|
return inner
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_composio_gmail_tool(
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
user_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> tuple[Any, str | None]:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return None, "Composio connected account ID not found for this Gmail connector."
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Gmail error")
|
||||||
|
|
||||||
|
return unwrap_composio_data(result.get("data")), None
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.gmail import GmailToolMetadataService
|
from app.services.gmail import GmailToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_gmail_draft tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_gmail_draft tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_gmail_draft(
|
async def create_gmail_draft(
|
||||||
to: str,
|
to: str,
|
||||||
|
|
@ -57,20 +75,23 @@ def create_create_gmail_draft_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
|
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Gmail tool not properly configured. Please contact support.",
|
"message": "Gmail tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GmailToolMetadataService(db_session)
|
metadata_service = GmailToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
accounts = context.get("accounts", [])
|
accounts = context.get("accounts", [])
|
||||||
|
|
@ -157,16 +178,13 @@ def create_create_gmail_draft_tool(
|
||||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -186,13 +204,17 @@ def create_create_gmail_draft_tool(
|
||||||
config_data["token"]
|
config_data["token"]
|
||||||
)
|
)
|
||||||
if config_data.get("refresh_token"):
|
if config_data.get("refresh_token"):
|
||||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["refresh_token"]
|
config_data["refresh_token"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if config_data.get("client_secret"):
|
if config_data.get("client_secret"):
|
||||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["client_secret"]
|
config_data["client_secret"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
exp = config_data.get("expiry", "")
|
exp = config_data.get("expiry", "")
|
||||||
if exp:
|
if exp:
|
||||||
|
|
@ -208,10 +230,6 @@ def create_create_gmail_draft_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
message = MIMEText(final_body)
|
message = MIMEText(final_body)
|
||||||
message["to"] = final_to
|
message["to"] = final_to
|
||||||
message["subject"] = final_subject
|
message["subject"] = final_subject
|
||||||
|
|
@ -222,6 +240,34 @@ def create_create_gmail_draft_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
split_recipients,
|
||||||
|
)
|
||||||
|
|
||||||
|
created, error = await execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
user_id,
|
||||||
|
"GMAIL_CREATE_EMAIL_DRAFT",
|
||||||
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(created, dict):
|
||||||
|
created = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
created = await asyncio.get_event_loop().run_in_executor(
|
created = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -285,7 +331,9 @@ def create_create_gmail_draft_tool(
|
||||||
draft_id=created.get("id"),
|
draft_id=created.get("id"),
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the read_gmail_email tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured read_gmail_email tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||||
"""Read the full content of a specific Gmail email by its message ID.
|
"""Read the full content of a specific Gmail email by its message ID.
|
||||||
|
|
@ -32,10 +49,11 @@ def create_read_gmail_email_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with status and the full email content formatted as markdown.
|
Dictionary with status and the full email content formatted as markdown.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
@ -50,7 +68,57 @@ def create_read_gmail_email_tool(
|
||||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
|
):
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||||
|
_format_gmail_summary,
|
||||||
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
service = ComposioService()
|
||||||
|
detail, error = await service.get_gmail_message_detail(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
if not detail:
|
||||||
|
return {
|
||||||
|
"status": "not_found",
|
||||||
|
"message": f"Email with ID '{message_id}' not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
summary = _format_gmail_summary(detail)
|
||||||
|
content = (
|
||||||
|
f"# {summary['subject']}\n\n"
|
||||||
|
f"**From:** {summary['from']}\n"
|
||||||
|
f"**To:** {summary['to']}\n"
|
||||||
|
f"**Date:** {summary['date']}\n\n"
|
||||||
|
f"## Message Content\n\n"
|
||||||
|
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
|
||||||
|
f"## Message Details\n\n"
|
||||||
|
f"- **Message ID:** {summary['message_id']}\n"
|
||||||
|
f"- **Thread ID:** {summary['thread_id']}\n"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": summary["message_id"] or message_id,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||||
|
_build_credentials,
|
||||||
|
)
|
||||||
|
|
||||||
creds = _build_credentials(connector)
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
|
|
@ -84,7 +152,11 @@ def create_read_gmail_email_tool(
|
||||||
|
|
||||||
content = gmail.format_message_to_markdown(detail)
|
content = gmail.format_message_to_markdown(detail)
|
||||||
|
|
||||||
return {"status": "success", "message_id": message_id, "content": content}
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": message_id,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
|
||||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
raise ValueError("Composio connectors must use Composio tool execution.")
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
|
||||||
if not cca_id:
|
|
||||||
raise ValueError("Composio connected account ID not found.")
|
|
||||||
return build_composio_credentials(cca_id)
|
|
||||||
|
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
|
@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
|
||||||
|
headers = message.get("payload", {}).get("headers", [])
|
||||||
|
return {
|
||||||
|
header.get("name", "").lower(): header.get("value", "")
|
||||||
|
for header in headers
|
||||||
|
if isinstance(header, dict)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
headers = _gmail_headers(message)
|
||||||
|
return {
|
||||||
|
"message_id": message.get("id") or message.get("messageId"),
|
||||||
|
"thread_id": message.get("threadId"),
|
||||||
|
"subject": message.get("subject") or headers.get("subject", "No Subject"),
|
||||||
|
"from": message.get("sender") or headers.get("from", "Unknown"),
|
||||||
|
"to": message.get("to") or headers.get("to", ""),
|
||||||
|
"date": message.get("messageTimestamp") or headers.get("date", ""),
|
||||||
|
"snippet": message.get("snippet") or message.get("messageText", "")[:300],
|
||||||
|
"labels": message.get("labelIds", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _search_composio_gmail(
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
user_id: str,
|
||||||
|
query: str,
|
||||||
|
max_results: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
service = ComposioService()
|
||||||
|
messages, _next_token, _estimate, error = await service.get_gmail_messages(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
query=query,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
|
emails = [_format_gmail_summary(message) for message in messages]
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"emails": emails,
|
||||||
|
"total": len(emails),
|
||||||
|
"message": "No emails found." if not emails else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_search_gmail_tool(
|
def create_search_gmail_tool(
|
||||||
db_session: AsyncSession | None = None,
|
db_session: AsyncSession | None = None,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the search_gmail tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured search_gmail tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def search_gmail(
|
async def search_gmail(
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -90,12 +159,13 @@ def create_search_gmail_tool(
|
||||||
Dictionary with status and a list of email summaries including
|
Dictionary with status and a list of email summaries including
|
||||||
message_id, subject, from, date, snippet.
|
message_id, subject, from, date, snippet.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||||
|
|
||||||
max_results = min(max_results, 20)
|
max_results = min(max_results, 20)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
@ -110,6 +180,14 @@ def create_search_gmail_tool(
|
||||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
|
):
|
||||||
|
return await _search_composio_gmail(
|
||||||
|
connector, str(user_id), query, max_results
|
||||||
|
)
|
||||||
|
|
||||||
creds = _build_credentials(connector)
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.gmail import GmailToolMetadataService
|
from app.services.gmail import GmailToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the send_gmail_email tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured send_gmail_email tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def send_gmail_email(
|
async def send_gmail_email(
|
||||||
to: str,
|
to: str,
|
||||||
|
|
@ -58,20 +76,23 @@ def create_send_gmail_email_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
|
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Gmail tool not properly configured. Please contact support.",
|
"message": "Gmail tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GmailToolMetadataService(db_session)
|
metadata_service = GmailToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
accounts = context.get("accounts", [])
|
accounts = context.get("accounts", [])
|
||||||
|
|
@ -158,16 +179,13 @@ def create_send_gmail_email_tool(
|
||||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -187,13 +205,17 @@ def create_send_gmail_email_tool(
|
||||||
config_data["token"]
|
config_data["token"]
|
||||||
)
|
)
|
||||||
if config_data.get("refresh_token"):
|
if config_data.get("refresh_token"):
|
||||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["refresh_token"]
|
config_data["refresh_token"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if config_data.get("client_secret"):
|
if config_data.get("client_secret"):
|
||||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["client_secret"]
|
config_data["client_secret"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
exp = config_data.get("expiry", "")
|
exp = config_data.get("expiry", "")
|
||||||
if exp:
|
if exp:
|
||||||
|
|
@ -209,10 +231,6 @@ def create_send_gmail_email_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
message = MIMEText(final_body)
|
message = MIMEText(final_body)
|
||||||
message["to"] = final_to
|
message["to"] = final_to
|
||||||
message["subject"] = final_subject
|
message["subject"] = final_subject
|
||||||
|
|
@ -223,6 +241,34 @@ def create_send_gmail_email_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
split_recipients,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent, error = await execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
user_id,
|
||||||
|
"GMAIL_SEND_EMAIL",
|
||||||
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(sent, dict):
|
||||||
|
sent = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
sent = await asyncio.get_event_loop().run_in_executor(
|
sent = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -286,7 +332,9 @@ def create_send_gmail_email_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.gmail import GmailToolMetadataService
|
from app.services.gmail import GmailToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the trash_gmail_email tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured trash_gmail_email tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def trash_gmail_email(
|
async def trash_gmail_email(
|
||||||
email_subject_or_id: str,
|
email_subject_or_id: str,
|
||||||
|
|
@ -55,13 +73,14 @@ def create_trash_gmail_email_tool(
|
||||||
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
|
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Gmail tool not properly configured. Please contact support.",
|
"message": "Gmail tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GmailToolMetadataService(db_session)
|
metadata_service = GmailToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_trash_context(
|
context = await metadata_service.get_trash_context(
|
||||||
search_space_id, user_id, email_subject_or_id
|
search_space_id, user_id, email_subject_or_id
|
||||||
|
|
@ -122,7 +141,9 @@ def create_trash_gmail_email_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
if not final_connector_id:
|
if not final_connector_id:
|
||||||
return {
|
return {
|
||||||
|
|
@ -158,16 +179,13 @@ def create_trash_gmail_email_tool(
|
||||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -187,13 +205,17 @@ def create_trash_gmail_email_tool(
|
||||||
config_data["token"]
|
config_data["token"]
|
||||||
)
|
)
|
||||||
if config_data.get("refresh_token"):
|
if config_data.get("refresh_token"):
|
||||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["refresh_token"]
|
config_data["refresh_token"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if config_data.get("client_secret"):
|
if config_data.get("client_secret"):
|
||||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["client_secret"]
|
config_data["client_secret"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
exp = config_data.get("expiry", "")
|
exp = config_data.get("expiry", "")
|
||||||
if exp:
|
if exp:
|
||||||
|
|
@ -209,11 +231,24 @@ def create_trash_gmail_email_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
_trashed, error = await execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
user_id,
|
||||||
|
"GMAIL_MOVE_TO_TRASH",
|
||||||
|
{"user_id": "me", "message_id": final_message_id},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.gmail import GmailToolMetadataService
|
from app.services.gmail import GmailToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the update_gmail_draft tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_gmail_draft tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_gmail_draft(
|
async def update_gmail_draft(
|
||||||
draft_subject_or_id: str,
|
draft_subject_or_id: str,
|
||||||
|
|
@ -76,13 +94,14 @@ def create_update_gmail_draft_tool(
|
||||||
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Gmail tool not properly configured. Please contact support.",
|
"message": "Gmail tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GmailToolMetadataService(db_session)
|
metadata_service = GmailToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, draft_subject_or_id
|
search_space_id, user_id, draft_subject_or_id
|
||||||
|
|
@ -188,16 +207,13 @@ def create_update_gmail_draft_tool(
|
||||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_gmail = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_gmail:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
|
@ -217,13 +233,17 @@ def create_update_gmail_draft_tool(
|
||||||
config_data["token"]
|
config_data["token"]
|
||||||
)
|
)
|
||||||
if config_data.get("refresh_token"):
|
if config_data.get("refresh_token"):
|
||||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["refresh_token"]
|
config_data["refresh_token"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if config_data.get("client_secret"):
|
if config_data.get("client_secret"):
|
||||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
config_data["client_secret"]
|
config_data["client_secret"]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
exp = config_data.get("expiry", "")
|
exp = config_data.get("expiry", "")
|
||||||
if exp:
|
if exp:
|
||||||
|
|
@ -239,15 +259,19 @@ def create_update_gmail_draft_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
# Resolve draft_id if not already available
|
# Resolve draft_id if not already available
|
||||||
if not final_draft_id:
|
if not final_draft_id:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||||
)
|
)
|
||||||
|
if is_composio_gmail:
|
||||||
|
final_draft_id = await _find_composio_draft_id_by_message(
|
||||||
|
connector, user_id, message_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
final_draft_id = await _find_draft_id_by_message(
|
final_draft_id = await _find_draft_id_by_message(
|
||||||
gmail_service, message_id
|
gmail_service, message_id
|
||||||
)
|
)
|
||||||
|
|
@ -272,6 +296,35 @@ def create_update_gmail_draft_tool(
|
||||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
split_recipients,
|
||||||
|
)
|
||||||
|
|
||||||
|
updated, error = await execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
user_id,
|
||||||
|
"GMAIL_UPDATE_DRAFT",
|
||||||
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"draft_id": final_draft_id,
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(updated, dict):
|
||||||
|
updated = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
updated = await asyncio.get_event_loop().run_in_executor(
|
updated = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to look up draft by message_id: {e}")
|
logger.warning(f"Failed to look up draft by message_id: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _find_composio_draft_id_by_message(
|
||||||
|
connector: Any, user_id: str, message_id: str
|
||||||
|
) -> str | None:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
page_token = ""
|
||||||
|
while True:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"user_id": "me",
|
||||||
|
"max_results": 100,
|
||||||
|
"verbose": False,
|
||||||
|
}
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
data, error = await execute_composio_gmail_tool(
|
||||||
|
connector, user_id, "GMAIL_LIST_DRAFTS", params
|
||||||
|
)
|
||||||
|
if error or not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for draft in data.get("drafts", []):
|
||||||
|
if draft.get("message", {}).get("id") == message_id:
|
||||||
|
return draft.get("id")
|
||||||
|
|
||||||
|
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
|
||||||
|
if not page_token:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_calendar_event tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_calendar_event tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_calendar_event(
|
async def create_calendar_event(
|
||||||
summary: str,
|
summary: str,
|
||||||
|
|
@ -60,20 +78,23 @@ def create_create_calendar_event_tool(
|
||||||
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
|
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
accounts = context.get("accounts", [])
|
accounts = context.get("accounts", [])
|
||||||
|
|
@ -113,7 +134,9 @@ def create_create_calendar_event_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
final_summary = result.params.get("summary", summary)
|
final_summary = result.params.get("summary", summary)
|
||||||
final_start_datetime = result.params.get("start_datetime", start_datetime)
|
final_start_datetime = result.params.get(
|
||||||
|
"start_datetime", start_datetime
|
||||||
|
)
|
||||||
final_end_datetime = result.params.get("end_datetime", end_datetime)
|
final_end_datetime = result.params.get("end_datetime", end_datetime)
|
||||||
final_description = result.params.get("description", description)
|
final_description = result.params.get("description", description)
|
||||||
final_location = result.params.get("location", location)
|
final_location = result.params.get("location", location)
|
||||||
|
|
@ -121,7 +144,10 @@ def create_create_calendar_event_tool(
|
||||||
final_connector_id = result.params.get("connector_id")
|
final_connector_id = result.params.get("connector_id")
|
||||||
|
|
||||||
if not final_summary or not final_summary.strip():
|
if not final_summary or not final_summary.strip():
|
||||||
return {"status": "error", "message": "Event summary cannot be empty."}
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Event summary cannot be empty.",
|
||||||
|
}
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
@ -168,16 +194,13 @@ def create_create_calendar_event_tool(
|
||||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -211,10 +234,6 @@ def create_create_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
tz = context.get("timezone", "UTC")
|
tz = context.get("timezone", "UTC")
|
||||||
event_body: dict[str, Any] = {
|
event_body: dict[str, Any] = {
|
||||||
"summary": final_summary,
|
"summary": final_summary,
|
||||||
|
|
@ -231,6 +250,43 @@ def create_create_calendar_event_tool(
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_calendar:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
composio_params = {
|
||||||
|
"calendar_id": "primary",
|
||||||
|
"summary": final_summary,
|
||||||
|
"start_datetime": final_start_datetime,
|
||||||
|
"end_datetime": final_end_datetime,
|
||||||
|
"timezone": tz,
|
||||||
|
"attendees": final_attendees or [],
|
||||||
|
}
|
||||||
|
if final_description:
|
||||||
|
composio_params["description"] = final_description
|
||||||
|
if final_location:
|
||||||
|
composio_params["location"] = final_location
|
||||||
|
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_CREATE_EVENT",
|
||||||
|
params=composio_params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
created = composio_result.get("data", {})
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("data", created)
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("response_data", created)
|
||||||
|
else:
|
||||||
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
created = await asyncio.get_event_loop().run_in_executor(
|
created = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -295,7 +351,9 @@ def create_create_calendar_event_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the delete_calendar_event tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_calendar_event tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_calendar_event(
|
async def delete_calendar_event(
|
||||||
event_title_or_id: str,
|
event_title_or_id: str,
|
||||||
|
|
@ -54,13 +72,14 @@ def create_delete_calendar_event_tool(
|
||||||
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
|
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_deletion_context(
|
context = await metadata_service.get_deletion_context(
|
||||||
search_space_id, user_id, event_title_or_id
|
search_space_id, user_id, event_title_or_id
|
||||||
|
|
@ -121,7 +140,9 @@ def create_delete_calendar_event_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
if not final_connector_id:
|
if not final_connector_id:
|
||||||
return {
|
return {
|
||||||
|
|
@ -159,16 +180,13 @@ def create_delete_calendar_event_tool(
|
||||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -202,11 +220,29 @@ def create_delete_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_composio_calendar:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_DELETE_EVENT",
|
||||||
|
params={
|
||||||
|
"calendar_id": "primary",
|
||||||
|
"event_id": final_event_id,
|
||||||
|
},
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
|
||||||
|
if "T" in value:
|
||||||
|
return value
|
||||||
|
time = "23:59:59" if is_end else "00:00:00"
|
||||||
|
return f"{value}T{time}Z"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
events = []
|
||||||
|
for ev in events_raw:
|
||||||
|
start = ev.get("start", {})
|
||||||
|
end = ev.get("end", {})
|
||||||
|
attendees_raw = ev.get("attendees", [])
|
||||||
|
events.append(
|
||||||
|
{
|
||||||
|
"event_id": ev.get("id"),
|
||||||
|
"summary": ev.get("summary", "No Title"),
|
||||||
|
"start": start.get("dateTime") or start.get("date", ""),
|
||||||
|
"end": end.get("dateTime") or end.get("date", ""),
|
||||||
|
"location": ev.get("location", ""),
|
||||||
|
"description": ev.get("description", ""),
|
||||||
|
"html_link": ev.get("htmlLink", ""),
|
||||||
|
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
||||||
|
"status": ev.get("status", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
def create_search_calendar_events_tool(
|
def create_search_calendar_events_tool(
|
||||||
db_session: AsyncSession | None = None,
|
db_session: AsyncSession | None = None,
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the search_calendar_events tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured search_calendar_events tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def search_calendar_events(
|
async def search_calendar_events(
|
||||||
start_date: str,
|
start_date: str,
|
||||||
|
|
@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
|
||||||
Dictionary with status and a list of events including
|
Dictionary with status and a list of events including
|
||||||
event_id, summary, start, end, location, attendees.
|
event_id, summary, start, end, location, attendees.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Calendar tool not properly configured.",
|
"message": "Calendar tool not properly configured.",
|
||||||
|
|
@ -47,6 +93,7 @@ def create_search_calendar_events_tool(
|
||||||
max_results = min(max_results, 50)
|
max_results = min(max_results, 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
@ -61,9 +108,34 @@ def create_search_calendar_events_tool(
|
||||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
|
):
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
events_raw, error = await ComposioService().get_calendar_events(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
time_min=_to_calendar_boundary(start_date, is_end=False),
|
||||||
|
time_max=_to_calendar_boundary(end_date, is_end=True),
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
if not events_raw and not error:
|
||||||
|
error = "No events found in the specified date range."
|
||||||
|
else:
|
||||||
creds = _build_credentials(connector)
|
creds = _build_credentials(connector)
|
||||||
|
|
||||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
from app.connectors.google_calendar_connector import (
|
||||||
|
GoogleCalendarConnector,
|
||||||
|
)
|
||||||
|
|
||||||
cal = GoogleCalendarConnector(
|
cal = GoogleCalendarConnector(
|
||||||
credentials=creds,
|
credentials=creds,
|
||||||
|
|
@ -97,24 +169,7 @@ def create_search_calendar_events_tool(
|
||||||
}
|
}
|
||||||
return {"status": "error", "message": error}
|
return {"status": "error", "message": error}
|
||||||
|
|
||||||
events = []
|
events = _format_calendar_events(events_raw)
|
||||||
for ev in events_raw:
|
|
||||||
start = ev.get("start", {})
|
|
||||||
end = ev.get("end", {})
|
|
||||||
attendees_raw = ev.get("attendees", [])
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"event_id": ev.get("id"),
|
|
||||||
"summary": ev.get("summary", "No Title"),
|
|
||||||
"start": start.get("dateTime") or start.get("date", ""),
|
|
||||||
"end": end.get("dateTime") or end.get("date", ""),
|
|
||||||
"location": ev.get("location", ""),
|
|
||||||
"description": ev.get("description", ""),
|
|
||||||
"html_link": ev.get("htmlLink", ""),
|
|
||||||
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
|
||||||
"status": ev.get("status", ""),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "success", "events": events, "total": len(events)}
|
return {"status": "success", "events": events, "total": len(events)}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the update_calendar_event tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_calendar_event tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_calendar_event(
|
async def update_calendar_event(
|
||||||
event_title_or_id: str,
|
event_title_or_id: str,
|
||||||
|
|
@ -74,13 +92,14 @@ def create_update_calendar_event_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
|
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, event_title_or_id
|
search_space_id, user_id, event_title_or_id
|
||||||
|
|
@ -192,16 +211,13 @@ def create_update_calendar_event_tool(
|
||||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
is_composio_calendar = (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_calendar:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
creds = build_composio_credentials(cca_id)
|
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Composio connected account ID not found for this connector.",
|
"message": "Composio connected account ID not found for this connector.",
|
||||||
|
|
@ -235,10 +251,6 @@ def create_update_calendar_event_tool(
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
service = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
|
||||||
)
|
|
||||||
|
|
||||||
update_body: dict[str, Any] = {}
|
update_body: dict[str, Any] = {}
|
||||||
if final_new_summary is not None:
|
if final_new_summary is not None:
|
||||||
update_body["summary"] = final_new_summary
|
update_body["summary"] = final_new_summary
|
||||||
|
|
@ -247,7 +259,9 @@ def create_update_calendar_event_tool(
|
||||||
final_new_start_datetime, context
|
final_new_start_datetime, context
|
||||||
)
|
)
|
||||||
if final_new_end_datetime is not None:
|
if final_new_end_datetime is not None:
|
||||||
update_body["end"] = _build_time_body(final_new_end_datetime, context)
|
update_body["end"] = _build_time_body(
|
||||||
|
final_new_end_datetime, context
|
||||||
|
)
|
||||||
if final_new_description is not None:
|
if final_new_description is not None:
|
||||||
update_body["description"] = final_new_description
|
update_body["description"] = final_new_description
|
||||||
if final_new_location is not None:
|
if final_new_location is not None:
|
||||||
|
|
@ -264,6 +278,53 @@ def create_update_calendar_event_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_calendar:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
composio_params: dict[str, Any] = {
|
||||||
|
"calendar_id": "primary",
|
||||||
|
"event_id": final_event_id,
|
||||||
|
}
|
||||||
|
if final_new_summary is not None:
|
||||||
|
composio_params["summary"] = final_new_summary
|
||||||
|
if final_new_start_datetime is not None:
|
||||||
|
composio_params["start_time"] = final_new_start_datetime
|
||||||
|
if final_new_end_datetime is not None:
|
||||||
|
composio_params["end_time"] = final_new_end_datetime
|
||||||
|
if final_new_description is not None:
|
||||||
|
composio_params["description"] = final_new_description
|
||||||
|
if final_new_location is not None:
|
||||||
|
composio_params["location"] = final_new_location
|
||||||
|
if final_new_attendees is not None:
|
||||||
|
composio_params["attendees"] = [
|
||||||
|
e.strip() for e in final_new_attendees if e.strip()
|
||||||
|
]
|
||||||
|
if not _is_date_only(
|
||||||
|
final_new_start_datetime or final_new_end_datetime or ""
|
||||||
|
):
|
||||||
|
composio_params["timezone"] = context.get("timezone", "UTC")
|
||||||
|
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_PATCH_EVENT",
|
||||||
|
params=composio_params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get(
|
||||||
|
"error", "Unknown Composio Calendar error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
updated = composio_result.get("data", {})
|
||||||
|
if isinstance(updated, dict):
|
||||||
|
updated = updated.get("data", updated)
|
||||||
|
if isinstance(updated, dict):
|
||||||
|
updated = updated.get("response_data", updated)
|
||||||
|
else:
|
||||||
|
service = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
updated = await asyncio.get_event_loop().run_in_executor(
|
updated = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -314,7 +375,9 @@ def create_update_calendar_event_tool(
|
||||||
kb_message_suffix = ""
|
kb_message_suffix = ""
|
||||||
if document_id is not None:
|
if document_id is not None:
|
||||||
try:
|
try:
|
||||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
from app.services.google_calendar import (
|
||||||
|
GoogleCalendarKBSyncService,
|
||||||
|
)
|
||||||
|
|
||||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||||
kb_result = await kb_service.sync_after_update(
|
kb_result = await kb_service.sync_after_update(
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.google_drive.client import GoogleDriveClient
|
from app.connectors.google_drive.client import GoogleDriveClient
|
||||||
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
|
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_google_drive_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Google Drive connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_google_drive_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_google_drive_file(
|
async def create_google_drive_file(
|
||||||
name: str,
|
name: str,
|
||||||
|
|
@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
|
||||||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||||
|
|
@ -78,18 +98,23 @@ def create_create_google_drive_file_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
accounts = context.get("accounts", [])
|
accounts = context.get("accounts", [])
|
||||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||||
logger.warning("All Google Drive accounts have expired authentication")
|
logger.warning(
|
||||||
|
"All Google Drive accounts have expired authentication"
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"status": "auth_error",
|
"status": "auth_error",
|
||||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||||
|
|
@ -179,23 +204,53 @@ def create_create_google_drive_file_tool(
|
||||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_creds = None
|
is_composio_drive = (
|
||||||
if (
|
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_drive:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this Drive connector.",
|
||||||
|
}
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=db_session,
|
session=db_session,
|
||||||
connector_id=actual_connector_id,
|
connector_id=actual_connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
if is_composio_drive:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"name": final_name,
|
||||||
|
"mimeType": mime_type,
|
||||||
|
"fields": "id,name,webViewLink,mimeType",
|
||||||
|
}
|
||||||
|
if final_parent_folder_id:
|
||||||
|
params["parents"] = [final_parent_folder_id]
|
||||||
|
if final_content:
|
||||||
|
params["description"] = final_content[:4096]
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLEDRIVE_CREATE_FILE",
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
result.get("error", "Unknown Composio Drive error")
|
||||||
|
)
|
||||||
|
created = result.get("data", {})
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("data", created)
|
||||||
|
if isinstance(created, dict):
|
||||||
|
created = created.get("response_data", created)
|
||||||
|
if not isinstance(created, dict):
|
||||||
|
created = {}
|
||||||
|
else:
|
||||||
created = await client.create_file(
|
created = await client.create_file(
|
||||||
name=final_name,
|
name=final_name,
|
||||||
mime_type=mime_type,
|
mime_type=mime_type,
|
||||||
|
|
@ -253,7 +308,9 @@ def create_create_google_drive_file_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.google_drive.client import GoogleDriveClient
|
from app.connectors.google_drive.client import GoogleDriveClient
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the delete_google_drive_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Google Drive connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_google_drive_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_google_drive_file(
|
async def delete_google_drive_file(
|
||||||
file_name: str,
|
file_name: str,
|
||||||
|
|
@ -55,13 +75,14 @@ def create_delete_google_drive_file_tool(
|
||||||
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_trash_context(
|
context = await metadata_service.get_trash_context(
|
||||||
search_space_id, user_id, file_name
|
search_space_id, user_id, file_name
|
||||||
|
|
@ -122,7 +143,9 @@ def create_delete_google_drive_file_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
if not final_connector_id:
|
if not final_connector_id:
|
||||||
return {
|
return {
|
||||||
|
|
@ -158,23 +181,37 @@ def create_delete_google_drive_file_tool(
|
||||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_creds = None
|
is_composio_drive = (
|
||||||
if (
|
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
):
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
if is_composio_drive:
|
||||||
|
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this Drive connector.",
|
||||||
|
}
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=db_session,
|
session=db_session,
|
||||||
connector_id=connector.id,
|
connector_id=connector.id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
if is_composio_drive:
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLEDRIVE_TRASH_FILE",
|
||||||
|
params={"file_id": final_file_id},
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
result.get("error", "Unknown Composio Drive error")
|
||||||
|
)
|
||||||
|
else:
|
||||||
await client.trash_file(file_id=final_file_id)
|
await client.trash_file(file_id=final_file_id)
|
||||||
except HttpError as http_err:
|
except HttpError as http_err:
|
||||||
if http_err.resp.status == 403:
|
if http_err.resp.status == 403:
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
"create_gmail_draft",
|
"create_gmail_draft",
|
||||||
"update_gmail_draft",
|
"update_gmail_draft",
|
||||||
|
"create_calendar_event",
|
||||||
"create_notion_page",
|
"create_notion_page",
|
||||||
"create_confluence_page",
|
"create_confluence_page",
|
||||||
"create_google_drive_file",
|
"create_google_drive_file",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.jira import JiraToolMetadataService
|
from app.services.jira import JiraToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the create_jira_issue tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits. Per-call sessions also
|
||||||
|
keep the request's outer transaction free of long-running Jira API
|
||||||
|
blocking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Jira connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_jira_issue tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_jira_issue(
|
async def create_jira_issue(
|
||||||
project_key: str,
|
project_key: str,
|
||||||
|
|
@ -49,10 +72,11 @@ def create_create_jira_issue_tool(
|
||||||
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
|
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = JiraToolMetadataService(db_session)
|
metadata_service = JiraToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
|
|
@ -97,7 +121,10 @@ def create_create_jira_issue_tool(
|
||||||
final_connector_id = result.params.get("connector_id", connector_id)
|
final_connector_id = result.params.get("connector_id", connector_id)
|
||||||
|
|
||||||
if not final_summary or not final_summary.strip():
|
if not final_summary or not final_summary.strip():
|
||||||
return {"status": "error", "message": "Issue summary cannot be empty."}
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Issue summary cannot be empty.",
|
||||||
|
}
|
||||||
if not final_project_key:
|
if not final_project_key:
|
||||||
return {"status": "error", "message": "A project must be selected."}
|
return {"status": "error", "message": "A project must be selected."}
|
||||||
|
|
||||||
|
|
@ -117,7 +144,10 @@ def create_create_jira_issue_tool(
|
||||||
)
|
)
|
||||||
connector = result.scalars().first()
|
connector = result.scalars().first()
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Jira connector found."}
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No Jira connector found.",
|
||||||
|
}
|
||||||
actual_connector_id = connector.id
|
actual_connector_id = connector.id
|
||||||
else:
|
else:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
|
|
@ -188,7 +218,9 @@ def create_create_jira_issue_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.jira import JiraToolMetadataService
|
from app.services.jira import JiraToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the delete_jira_issue tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Jira connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_jira_issue tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_jira_issue(
|
async def delete_jira_issue(
|
||||||
issue_title_or_key: str,
|
issue_title_or_key: str,
|
||||||
|
|
@ -44,10 +65,11 @@ def create_delete_jira_issue_tool(
|
||||||
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = JiraToolMetadataService(db_session)
|
metadata_service = JiraToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_deletion_context(
|
context = await metadata_service.get_deletion_context(
|
||||||
search_space_id, user_id, issue_title_or_key
|
search_space_id, user_id, issue_title_or_key
|
||||||
|
|
@ -92,7 +114,9 @@ def create_delete_jira_issue_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
@ -129,7 +153,10 @@ def create_delete_jira_issue_tool(
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
if "status code 403" in str(api_err).lower():
|
if "status code 403" in str(api_err).lower():
|
||||||
try:
|
try:
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
flag_modified(connector, "config")
|
flag_modified(connector, "config")
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.jira import JiraToolMetadataService
|
from app.services.jira import JiraToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the update_jira_issue tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Jira connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_jira_issue tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_jira_issue(
|
async def update_jira_issue(
|
||||||
issue_title_or_key: str,
|
issue_title_or_key: str,
|
||||||
|
|
@ -48,10 +69,11 @@ def create_update_jira_issue_tool(
|
||||||
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = JiraToolMetadataService(db_session)
|
metadata_service = JiraToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, issue_title_or_key
|
search_space_id, user_id, issue_title_or_key
|
||||||
|
|
@ -97,7 +119,9 @@ def create_update_jira_issue_tool(
|
||||||
|
|
||||||
final_issue_key = result.params.get("issue_key", issue_key)
|
final_issue_key = result.params.get("issue_key", issue_key)
|
||||||
final_summary = result.params.get("new_summary", new_summary)
|
final_summary = result.params.get("new_summary", new_summary)
|
||||||
final_description = result.params.get("new_description", new_description)
|
final_description = result.params.get(
|
||||||
|
"new_description", new_description
|
||||||
|
)
|
||||||
final_priority = result.params.get("new_priority", new_priority)
|
final_priority = result.params.get("new_priority", new_priority)
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
|
|
@ -140,7 +164,9 @@ def create_update_jira_issue_tool(
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "paragraph",
|
"type": "paragraph",
|
||||||
"content": [{"type": "text", "text": final_description}],
|
"content": [
|
||||||
|
{"type": "text", "text": final_description}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
@ -161,7 +187,10 @@ def create_update_jira_issue_tool(
|
||||||
except Exception as api_err:
|
except Exception as api_err:
|
||||||
if "status code 403" in str(api_err).lower():
|
if "status code 403" in str(api_err).lower():
|
||||||
try:
|
try:
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
flag_modified(connector, "config")
|
flag_modified(connector, "config")
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.linear import LinearToolMetadataService
|
from app.services.linear import LinearToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Factory function to create the create_linear_issue tool.
|
||||||
Factory function to create the create_linear_issue tool.
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing the Linear connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Linear connector
|
search_space_id: Search space ID to find the Linear connector
|
||||||
user_id: User ID for fetching user-specific context
|
user_id: User ID for fetching user-specific context
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured create_linear_issue tool
|
Configured create_linear_issue tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_linear_issue(
|
async def create_linear_issue(
|
||||||
|
|
@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_linear_issue called: title='{title}'")
|
logger.info(f"create_linear_issue called: title='{title}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Linear tool not properly configured - missing required parameters"
|
"Linear tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -75,13 +83,16 @@ def create_create_linear_issue_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = LinearToolMetadataService(db_session)
|
metadata_service = LinearToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
workspaces = context.get("workspaces", [])
|
workspaces = context.get("workspaces", [])
|
||||||
|
|
@ -128,7 +139,10 @@ def create_create_linear_issue_tool(
|
||||||
|
|
||||||
if not final_title or not final_title.strip():
|
if not final_title or not final_title.strip():
|
||||||
logger.error("Title is empty or contains only whitespace")
|
logger.error("Title is empty or contains only whitespace")
|
||||||
return {"status": "error", "message": "Issue title cannot be empty."}
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Issue title cannot be empty.",
|
||||||
|
}
|
||||||
if not final_team_id:
|
if not final_team_id:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
|
|
@ -192,7 +206,9 @@ def create_create_linear_issue_tool(
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("status") == "error":
|
if result.get("status") == "error":
|
||||||
logger.error(f"Failed to create Linear issue: {result.get('message')}")
|
logger.error(
|
||||||
|
f"Failed to create Linear issue: {result.get('message')}"
|
||||||
|
)
|
||||||
return {"status": "error", "message": result.get("message")}
|
return {"status": "error", "message": result.get("message")}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -215,7 +231,9 @@ def create_create_linear_issue_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.linear import LinearToolMetadataService
|
from app.services.linear import LinearToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Factory function to create the delete_linear_issue tool.
|
||||||
Factory function to create the delete_linear_issue tool.
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing the Linear connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Linear connector
|
search_space_id: Search space ID to find the Linear connector
|
||||||
user_id: User ID for finding the correct Linear connector
|
user_id: User ID for finding the correct Linear connector
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured delete_linear_issue tool
|
Configured delete_linear_issue tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_linear_issue(
|
async def delete_linear_issue(
|
||||||
|
|
@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
|
||||||
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
|
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Linear tool not properly configured - missing required parameters"
|
"Linear tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -83,6 +91,7 @@ def create_delete_linear_issue_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = LinearToolMetadataService(db_session)
|
metadata_service = LinearToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_delete_context(
|
context = await metadata_service.get_delete_context(
|
||||||
search_space_id, user_id, issue_ref
|
search_space_id, user_id, issue_ref
|
||||||
|
|
@ -136,7 +145,9 @@ def create_delete_linear_issue_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
|
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Factory function to create the update_linear_issue tool.
|
||||||
Factory function to create the update_linear_issue tool.
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing the Linear connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Linear connector
|
search_space_id: Search space ID to find the Linear connector
|
||||||
user_id: User ID for fetching user-specific context
|
user_id: User ID for fetching user-specific context
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured update_linear_issue tool
|
Configured update_linear_issue tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_linear_issue(
|
async def update_linear_issue(
|
||||||
|
|
@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
|
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Linear tool not properly configured - missing required parameters"
|
"Linear tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -96,6 +104,7 @@ def create_update_linear_issue_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = LinearToolMetadataService(db_session)
|
metadata_service = LinearToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, issue_ref
|
search_space_id, user_id, issue_ref
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
|
|
@ -17,6 +18,23 @@ def create_create_luma_event_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_luma_event tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_luma_event tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_luma_event(
|
async def create_luma_event(
|
||||||
name: str,
|
name: str,
|
||||||
|
|
@ -40,11 +58,14 @@ def create_create_luma_event_tool(
|
||||||
IMPORTANT:
|
IMPORTANT:
|
||||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_luma_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Luma connector found."}
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_list_luma_events_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the list_luma_events tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured list_luma_events tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_luma_events(
|
async def list_luma_events(
|
||||||
max_results: int = 25,
|
max_results: int = 25,
|
||||||
|
|
@ -28,13 +47,16 @@ def create_list_luma_events_tool(
|
||||||
Dictionary with status and a list of events including
|
Dictionary with status and a list of events including
|
||||||
event_id, name, start_at, end_at, location, url.
|
event_id, name, start_at, end_at, location, url.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
max_results = min(max_results, 50)
|
max_results = min(max_results, 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_luma_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Luma connector found."}
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_read_luma_event_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the read_luma_event tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured read_luma_event tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
||||||
"""Read detailed information about a specific Luma event.
|
"""Read detailed information about a specific Luma event.
|
||||||
|
|
@ -26,11 +45,14 @@ def create_read_luma_event_tool(
|
||||||
Dictionary with status and full event details including
|
Dictionary with status and full event details including
|
||||||
description, attendees count, meeting URL.
|
description, attendees count, meeting URL.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_luma_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Luma connector found."}
|
return {"status": "error", "message": "No Luma connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.notion import NotionToolMetadataService
|
from app.services.notion import NotionToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -20,8 +21,17 @@ def create_create_notion_page_tool(
|
||||||
"""
|
"""
|
||||||
Factory function to create the create_notion_page tool.
|
Factory function to create the create_notion_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits. Per-call sessions also
|
||||||
|
keep the request's outer transaction free of long-running Notion API
|
||||||
|
blocking.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing Notion connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Notion connector
|
search_space_id: Search space ID to find the Notion connector
|
||||||
user_id: User ID for fetching user-specific context
|
user_id: User ID for fetching user-specific context
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +39,7 @@ def create_create_notion_page_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured create_notion_page tool
|
Configured create_notion_page tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_notion_page(
|
async def create_notion_page(
|
||||||
|
|
@ -67,7 +78,7 @@ def create_create_notion_page_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_notion_page called: title='{title}'")
|
logger.info(f"create_notion_page called: title='{title}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Notion tool not properly configured - missing required parameters"
|
"Notion tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -77,13 +88,16 @@ def create_create_notion_page_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = NotionToolMetadataService(db_session)
|
metadata_service = NotionToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_creation_context(
|
context = await metadata_service.get_creation_context(
|
||||||
search_space_id, user_id
|
search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
logger.error(
|
||||||
|
f"Failed to fetch creation context: {context['error']}"
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": context["error"],
|
"message": context["error"],
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.notion.tool_metadata_service import NotionToolMetadataService
|
from app.services.notion.tool_metadata_service import NotionToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
|
||||||
"""
|
"""
|
||||||
Factory function to create the delete_notion_page tool.
|
Factory function to create the delete_notion_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing Notion connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Notion connector
|
search_space_id: Search space ID to find the Notion connector
|
||||||
user_id: User ID for finding the correct Notion connector
|
user_id: User ID for finding the correct Notion connector
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured delete_notion_page tool
|
Configured delete_notion_page tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_notion_page(
|
async def delete_notion_page(
|
||||||
|
|
@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
|
||||||
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
|
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Notion tool not properly configured - missing required parameters"
|
"Notion tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -73,6 +81,7 @@ def create_delete_notion_page_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
# Get page context (page_id, account, title) from indexed data
|
# Get page context (page_id, account, title) from indexed data
|
||||||
metadata_service = NotionToolMetadataService(db_session)
|
metadata_service = NotionToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_delete_context(
|
context = await metadata_service.get_delete_context(
|
||||||
|
|
@ -136,7 +145,9 @@ def create_delete_notion_page_tool(
|
||||||
final_connector_id = result.params.get(
|
final_connector_id = result.params.get(
|
||||||
"connector_id", connector_id_from_context
|
"connector_id", connector_id_from_context
|
||||||
)
|
)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.notion import NotionToolMetadataService
|
from app.services.notion import NotionToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -20,8 +21,14 @@ def create_update_notion_page_tool(
|
||||||
"""
|
"""
|
||||||
Factory function to create the update_notion_page tool.
|
Factory function to create the update_notion_page tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache (see
|
||||||
|
``create_create_notion_page_tool`` for the full rationale).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for accessing Notion connector
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
search_space_id: Search space ID to find the Notion connector
|
search_space_id: Search space ID to find the Notion connector
|
||||||
user_id: User ID for fetching user-specific context
|
user_id: User ID for fetching user-specific context
|
||||||
connector_id: Optional specific connector ID (if known)
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
@ -29,6 +36,7 @@ def create_update_notion_page_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Configured update_notion_page tool
|
Configured update_notion_page tool
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_notion_page(
|
async def update_notion_page(
|
||||||
|
|
@ -71,7 +79,7 @@ def create_update_notion_page_tool(
|
||||||
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
|
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Notion tool not properly configured - missing required parameters"
|
"Notion tool not properly configured - missing required parameters"
|
||||||
)
|
)
|
||||||
|
|
@ -88,6 +96,7 @@ def create_update_notion_page_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
metadata_service = NotionToolMetadataService(db_session)
|
metadata_service = NotionToolMetadataService(db_session)
|
||||||
context = await metadata_service.get_update_context(
|
context = await metadata_service.get_update_context(
|
||||||
search_space_id, user_id, page_title
|
search_space_id, user_id, page_title
|
||||||
|
|
@ -204,7 +213,9 @@ def create_update_notion_page_tool(
|
||||||
if result.get("status") == "success" and document_id is not None:
|
if result.get("status") == "success" and document_id is not None:
|
||||||
from app.services.notion import NotionKBSyncService
|
from app.services.notion import NotionKBSyncService
|
||||||
|
|
||||||
logger.info(f"Updating knowledge base for document {document_id}...")
|
logger.info(
|
||||||
|
f"Updating knowledge base for document {document_id}..."
|
||||||
|
)
|
||||||
kb_service = NotionKBSyncService(db_session)
|
kb_service = NotionKBSyncService(db_session)
|
||||||
kb_result = await kb_service.sync_after_update(
|
kb_result = await kb_service.sync_after_update(
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.onedrive.client import OneDriveClient
|
from app.connectors.onedrive.client import OneDriveClient
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the create_onedrive_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured create_onedrive_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_onedrive_file(
|
async def create_onedrive_file(
|
||||||
name: str,
|
name: str,
|
||||||
|
|
@ -70,13 +87,14 @@ def create_create_onedrive_file_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_onedrive_file called: name='{name}'")
|
logger.info(f"create_onedrive_file called: name='{name}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "OneDrive tool not properly configured.",
|
"message": "OneDrive tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
|
@ -136,7 +154,9 @@ def create_create_onedrive_file_tool(
|
||||||
]
|
]
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Error fetching folders for connector %s", cid, exc_info=True
|
"Error fetching folders for connector %s",
|
||||||
|
cid,
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
parent_folders[cid] = []
|
parent_folders[cid] = []
|
||||||
|
|
||||||
|
|
@ -223,7 +243,9 @@ def create_create_onedrive_file_tool(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
if kb_result["status"] == "success":
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
except Exception as kb_err:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from app.db import (
|
||||||
DocumentType,
|
DocumentType,
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
|
async_session_maker,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the delete_onedrive_file tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_onedrive_file tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_onedrive_file(
|
async def delete_onedrive_file(
|
||||||
file_name: str,
|
file_name: str,
|
||||||
|
|
@ -56,13 +74,14 @@ def create_delete_onedrive_file_tool(
|
||||||
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "OneDrive tool not properly configured.",
|
"message": "OneDrive tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
doc_result = await db_session.execute(
|
doc_result = await db_session.execute(
|
||||||
select(Document)
|
select(Document)
|
||||||
.join(
|
.join(
|
||||||
|
|
@ -95,7 +114,9 @@ def create_delete_onedrive_file_tool(
|
||||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||||
func.lower(
|
func.lower(
|
||||||
cast(
|
cast(
|
||||||
Document.document_metadata["onedrive_file_name"],
|
Document.document_metadata[
|
||||||
|
"onedrive_file_name"
|
||||||
|
],
|
||||||
String,
|
String,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -193,14 +214,17 @@ def create_delete_onedrive_file_tool(
|
||||||
|
|
||||||
final_file_id = result.params.get("file_id", file_id)
|
final_file_id = result.params.get("file_id", file_id)
|
||||||
final_connector_id = result.params.get("connector_id", connector.id)
|
final_connector_id = result.params.get("connector_id", connector.id)
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
if final_connector_id != connector.id:
|
if final_connector_id != connector.id:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
and_(
|
and_(
|
||||||
SearchSourceConnector.id == final_connector_id,
|
SearchSourceConnector.id == final_connector_id,
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id
|
||||||
|
== search_space_id,
|
||||||
SearchSourceConnector.user_id == user_id,
|
SearchSourceConnector.user_id == user_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.connector_type
|
||||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||||
|
|
|
||||||
|
|
@ -824,13 +824,22 @@ async def build_tools_async(
|
||||||
"""Async version of build_tools that also loads MCP tools from database.
|
"""Async version of build_tools that also loads MCP tools from database.
|
||||||
|
|
||||||
Design Note:
|
Design Note:
|
||||||
This function exists because MCP tools require database queries to load user configs,
|
This function exists because MCP tools require database queries to load
|
||||||
while built-in tools are created synchronously from static code.
|
user configs, while built-in tools are created synchronously from static
|
||||||
|
code.
|
||||||
|
|
||||||
Alternative: We could make build_tools() itself async and always query the database,
|
Alternative: We could make build_tools() itself async and always query
|
||||||
but that would force async everywhere even when only using built-in tools. The current
|
the database, but that would force async everywhere even when only using
|
||||||
design keeps the simple case (static tools only) synchronous while supporting dynamic
|
built-in tools. The current design keeps the simple case (static tools
|
||||||
database-loaded tools through this async wrapper.
|
only) synchronous while supporting dynamic database-loaded tools through
|
||||||
|
this async wrapper.
|
||||||
|
|
||||||
|
Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
|
||||||
|
avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
|
||||||
|
the event loop) are kicked off concurrently. Cold-path savings are
|
||||||
|
bounded by the slower of the two — typically MCP at ~200ms-1.7s —
|
||||||
|
so the parallelization recovers the ~50-200ms previously spent
|
||||||
|
serially on built-in construction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dependencies: Dict containing all possible dependencies
|
dependencies: Dict containing all possible dependencies
|
||||||
|
|
@ -843,33 +852,70 @@ async def build_tools_async(
|
||||||
List of configured tool instances ready for the agent, including MCP tools.
|
List of configured tool instances ready for the agent, including MCP tools.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
|
||||||
_perf_log = logging.getLogger("surfsense.perf")
|
_perf_log = logging.getLogger("surfsense.perf")
|
||||||
_perf_log.setLevel(logging.DEBUG)
|
_perf_log.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
can_load_mcp = (
|
||||||
|
include_mcp_tools
|
||||||
|
and "db_session" in dependencies
|
||||||
|
and "search_space_id" in dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
# Built-in tool construction is synchronous + CPU-only. Off-loop it so
|
||||||
|
# MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
|
||||||
|
# function over its inputs — safe to thread-shift.
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
|
builtin_task = asyncio.create_task(
|
||||||
|
asyncio.to_thread(
|
||||||
|
build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp_task: asyncio.Task | None = None
|
||||||
|
if can_load_mcp:
|
||||||
|
mcp_task = asyncio.create_task(
|
||||||
|
load_mcp_tools(
|
||||||
|
dependencies["db_session"],
|
||||||
|
dependencies["search_space_id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Surface failures from each task independently so a flaky MCP
|
||||||
|
# endpoint never poisons built-in tool registration. ``return_exceptions``
|
||||||
|
# gives us per-task exceptions instead of dropping the second result
|
||||||
|
# when the first raises.
|
||||||
|
if mcp_task is not None:
|
||||||
|
builtin_result, mcp_result = await asyncio.gather(
|
||||||
|
builtin_task, mcp_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
builtin_result = await builtin_task
|
||||||
|
mcp_result = None
|
||||||
|
|
||||||
|
if isinstance(builtin_result, BaseException):
|
||||||
|
raise builtin_result # built-in registration failure is non-recoverable
|
||||||
|
tools: list[BaseTool] = builtin_result
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
|
"[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
len(tools),
|
len(tools),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load MCP tools if requested and dependencies are available
|
if mcp_task is not None:
|
||||||
if (
|
if isinstance(mcp_result, BaseException):
|
||||||
include_mcp_tools
|
# ``return_exceptions=True`` captures the exception out-of-band,
|
||||||
and "db_session" in dependencies
|
# so ``sys.exc_info()`` is empty here. Pass the captured
|
||||||
and "search_space_id" in dependencies
|
# exception via ``exc_info=`` to get a real traceback.
|
||||||
):
|
logging.error(
|
||||||
try:
|
"Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
|
||||||
_t0 = time.perf_counter()
|
|
||||||
mcp_tools = await load_mcp_tools(
|
|
||||||
dependencies["db_session"],
|
|
||||||
dependencies["search_space_id"],
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
mcp_tools = mcp_result or []
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
|
"[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
len(mcp_tools),
|
len(mcp_tools),
|
||||||
)
|
)
|
||||||
|
|
@ -879,8 +925,6 @@ async def build_tools_async(
|
||||||
len(mcp_tools),
|
len(mcp_tools),
|
||||||
[t.name for t in mcp_tools],
|
[t.name for t in mcp_tools],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logging.exception("Failed to load MCP tools: %s", e)
|
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
"Total tools for agent: %d — %s",
|
"Total tools for agent: %d — %s",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
|
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
|
||||||
from app.utils.document_converters import embed_text
|
from app.utils.document_converters import embed_text
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
|
||||||
"""
|
"""
|
||||||
Factory function to create the search_surfsense_docs tool.
|
Factory function to create the search_surfsense_docs tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session: Database session for executing queries
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A configured tool function for searching Surfsense documentation
|
A configured tool function for searching Surfsense documentation
|
||||||
"""
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
|
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
|
||||||
|
|
@ -155,6 +162,7 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
|
||||||
Returns:
|
Returns:
|
||||||
Relevant documentation content formatted with chunk IDs for citations
|
Relevant documentation content formatted with chunk IDs for citations
|
||||||
"""
|
"""
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
return await search_surfsense_docs_async(
|
return await search_surfsense_docs_async(
|
||||||
query=query,
|
query=query,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the list_teams_channels tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured list_teams_channels tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_teams_channels() -> dict[str, Any]:
|
async def list_teams_channels() -> dict[str, Any]:
|
||||||
"""List all Microsoft Teams and their channels the user has access to.
|
"""List all Microsoft Teams and their channels the user has access to.
|
||||||
|
|
@ -23,11 +42,14 @@ def create_list_teams_channels_tool(
|
||||||
Dictionary with status and a list of teams, each containing
|
Dictionary with status and a list of teams, each containing
|
||||||
team_id, team_name, and a list of channels (id, name).
|
team_id, team_name, and a list of channels (id, name).
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_teams_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Teams connector found."}
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import httpx
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the read_teams_messages tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured read_teams_messages tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def read_teams_messages(
|
async def read_teams_messages(
|
||||||
team_id: str,
|
team_id: str,
|
||||||
|
|
@ -32,13 +51,16 @@ def create_read_teams_messages_tool(
|
||||||
Dictionary with status and a list of messages including
|
Dictionary with status and a list of messages including
|
||||||
id, sender, content, timestamp.
|
id, sender, content, timestamp.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
limit = min(limit, 50)
|
limit = min(limit, 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_teams_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Teams connector found."}
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||||
|
|
||||||
|
|
@ -17,6 +18,23 @@ def create_send_teams_message_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the send_teams_message tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured send_teams_message tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def send_teams_message(
|
async def send_teams_message(
|
||||||
team_id: str,
|
team_id: str,
|
||||||
|
|
@ -39,11 +57,14 @@ def create_send_teams_message_tool(
|
||||||
IMPORTANT:
|
IMPORTANT:
|
||||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
async with async_session_maker() as db_session:
|
||||||
|
connector = await get_teams_connector(
|
||||||
|
db_session, search_space_id, user_id
|
||||||
|
)
|
||||||
if not connector:
|
if not connector:
|
||||||
return {"status": "error", "message": "No Teams connector found."}
|
return {"status": "error", "message": "No Teams connector found."}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import SearchSpace, User
|
from app.db import SearchSpace, User, async_session_maker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -295,6 +295,25 @@ def create_update_memory_tool(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
llm: Any | None = None,
|
llm: Any | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the user-memory update tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
The session's bound ``commit``/``rollback`` methods are captured at
|
||||||
|
call time, after ``async with`` has bound ``db_session`` locally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: ID of the user whose memory document is being updated.
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
llm: Optional LLM for the forced-rewrite path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_memory tool for the user-memory scope.
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -311,6 +330,7 @@ def create_update_memory_tool(
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
updated_memory: The FULL updated markdown document (not a diff).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(select(User).where(User.id == uid))
|
result = await db_session.execute(select(User).where(User.id == uid))
|
||||||
user = result.scalars().first()
|
user = result.scalars().first()
|
||||||
if not user:
|
if not user:
|
||||||
|
|
@ -330,7 +350,6 @@ def create_update_memory_tool(
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update user memory: %s", e)
|
logger.exception("Failed to update user memory: %s", e)
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"Failed to update memory: {e}",
|
"message": f"Failed to update memory: {e}",
|
||||||
|
|
@ -344,6 +363,27 @@ def create_update_team_memory_tool(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
llm: Any | None = None,
|
llm: Any | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the team-memory update tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
The session's bound ``commit``/``rollback`` methods are captured at
|
||||||
|
call time, after ``async with`` has bound ``db_session`` locally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_space_id: ID of the search space whose team memory is being
|
||||||
|
updated.
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
llm: Optional LLM for the forced-rewrite path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured update_memory tool for the team-memory scope.
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||||
"""Update the team's shared memory document for this search space.
|
"""Update the team's shared memory document for this search space.
|
||||||
|
|
@ -359,6 +399,7 @@ def create_update_team_memory_tool(
|
||||||
updated_memory: The FULL updated markdown document (not a diff).
|
updated_memory: The FULL updated markdown document (not a diff).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||||
)
|
)
|
||||||
|
|
@ -372,7 +413,9 @@ def create_update_team_memory_tool(
|
||||||
updated_memory=updated_memory,
|
updated_memory=updated_memory,
|
||||||
old_memory=old_memory,
|
old_memory=old_memory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
apply_fn=lambda content: setattr(
|
||||||
|
space, "shared_memory_md", content
|
||||||
|
),
|
||||||
commit_fn=db_session.commit,
|
commit_fn=db_session.commit,
|
||||||
rollback_fn=db_session.rollback,
|
rollback_fn=db_session.rollback,
|
||||||
label="team memory",
|
label="team memory",
|
||||||
|
|
@ -380,7 +423,6 @@ def create_update_team_memory_tool(
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to update team memory: %s", e)
|
logger.exception("Failed to update team memory: %s", e)
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"Failed to update team memory: {e}",
|
"message": f"Failed to update team memory: {e}",
|
||||||
|
|
|
||||||
|
|
@ -421,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
|
||||||
OpenRouterIntegrationService.get_instance().stop_background_refresh()
|
OpenRouterIntegrationService.get_instance().stop_background_refresh()
|
||||||
|
|
||||||
|
|
||||||
|
async def _warm_agent_jit_caches() -> None:
|
||||||
|
"""Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
|
||||||
|
|
||||||
|
Why
|
||||||
|
----
|
||||||
|
A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
|
||||||
|
generation chain takes 1.5-2 seconds of pure CPU on first invocation
|
||||||
|
inside any Python process: the graph compiler builds reducers,
|
||||||
|
Pydantic v2 generates and JITs validator schemas, deepagents
|
||||||
|
eagerly compiles its general-purpose subagent, etc. Subsequent
|
||||||
|
compiles in the same process pay only ~50% of that cost (the lazy
|
||||||
|
JIT bits are cached in module-level dicts).
|
||||||
|
|
||||||
|
Doing one throwaway compile during ``lifespan`` startup pre-pays
|
||||||
|
that cost so the *first real request* doesn't. We do NOT prime
|
||||||
|
:mod:`agent_cache` because the cache key requires real
|
||||||
|
``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the
|
||||||
|
throwaway agent is genuinely thrown away and immediately collected.
|
||||||
|
|
||||||
|
Safety
|
||||||
|
------
|
||||||
|
* No DB access. We construct a stub LLM (no real keys), pass an
|
||||||
|
empty tools list, and pass ``checkpointer=None`` so we never
|
||||||
|
touch Postgres.
|
||||||
|
* Bounded by ``asyncio.wait_for`` so a hang here can never block
|
||||||
|
worker startup. On any failure, we log + swallow — the worst
|
||||||
|
case is the first real request pays the full cold cost (i.e.
|
||||||
|
pre-warmup behaviour).
|
||||||
|
"""
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
t0 = _time.perf_counter()
|
||||||
|
try:
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain.agents.middleware import (
|
||||||
|
ModelCallLimitMiddleware,
|
||||||
|
TodoListMiddleware,
|
||||||
|
ToolCallLimitMiddleware,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.fake_chat_models import (
|
||||||
|
FakeListChatModel,
|
||||||
|
)
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
|
|
||||||
|
# Minimal LLM stub. ``FakeListChatModel`` satisfies
|
||||||
|
# ``BaseChatModel`` without any network or auth — perfect for
|
||||||
|
# exercising the compile path without side effects.
|
||||||
|
stub_llm = FakeListChatModel(responses=["warmup-response"])
|
||||||
|
|
||||||
|
# Two trivial tools with arg + return schemas — exercises the
|
||||||
|
# Pydantic v2 schema JIT path. Without at least one tool the
|
||||||
|
# graph compile skips the tool-loop bytecode generation that
|
||||||
|
# accounts for ~30-50% of cold compile cost.
|
||||||
|
@tool
|
||||||
|
def _warmup_tool_a(query: str, limit: int = 5) -> str:
|
||||||
|
"""Warmup tool A — never actually invoked."""
|
||||||
|
return query[:limit]
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
|
||||||
|
"""Warmup tool B — never actually invoked."""
|
||||||
|
return {"name": name, "value": value}
|
||||||
|
|
||||||
|
# A handful of common middleware so the compile pre-pays the
|
||||||
|
# ``AgentMiddleware`` resolver path. These instances never run
|
||||||
|
# because the throwaway agent is immediately collected.
|
||||||
|
# ``SubAgentMiddleware`` is the single heaviest line in cold
|
||||||
|
# ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
|
||||||
|
# compile its general-purpose subagent's full inner graph),
|
||||||
|
# so we include it here to make sure that compile path is JIT'd.
|
||||||
|
warmup_middleware: list = [
|
||||||
|
TodoListMiddleware(),
|
||||||
|
ModelCallLimitMiddleware(
|
||||||
|
thread_limit=120, run_limit=80, exit_behavior="end"
|
||||||
|
),
|
||||||
|
ToolCallLimitMiddleware(
|
||||||
|
thread_limit=300, run_limit=80, exit_behavior="continue"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
from deepagents import SubAgentMiddleware
|
||||||
|
from deepagents.backends import StateBackend
|
||||||
|
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
|
|
||||||
|
gp_warmup_spec = { # type: ignore[var-annotated]
|
||||||
|
**GENERAL_PURPOSE_SUBAGENT,
|
||||||
|
"model": stub_llm,
|
||||||
|
"tools": [_warmup_tool_a],
|
||||||
|
"middleware": [TodoListMiddleware()],
|
||||||
|
}
|
||||||
|
warmup_middleware.append(
|
||||||
|
SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Deepagents missing/incompatible — middleware-only warmup
|
||||||
|
# still produces a useful (smaller) speedup.
|
||||||
|
logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
|
||||||
|
|
||||||
|
compiled = create_agent(
|
||||||
|
stub_llm,
|
||||||
|
tools=[_warmup_tool_a, _warmup_tool_b],
|
||||||
|
system_prompt="You are a warmup stub.",
|
||||||
|
middleware=warmup_middleware,
|
||||||
|
context_schema=SurfSenseContextSchema,
|
||||||
|
checkpointer=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Touch the compiled graph's stream_channels / nodes so any
|
||||||
|
# remaining lazy schema work fires now instead of on first
|
||||||
|
# real invocation.
|
||||||
|
_ = list(getattr(compiled, "nodes", {}).keys())
|
||||||
|
|
||||||
|
del compiled
|
||||||
|
logger.info(
|
||||||
|
"[startup] Agent JIT warmup completed in %.3fs",
|
||||||
|
_time.perf_counter() - t0,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
|
||||||
|
"real request will pay the full compile cost)",
|
||||||
|
_time.perf_counter() - t0,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||||
|
|
@ -445,6 +574,18 @@ async def lifespan(app: FastAPI):
|
||||||
"Docs will be indexed on the next restart."
|
"Docs will be indexed on the next restart."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||||
|
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||||
|
# doesn't leave half-warmed Pydantic schemas in an inconsistent
|
||||||
|
# state.
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
|
||||||
|
except (TimeoutError, Exception): # pragma: no cover - defensive
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"[startup] Agent JIT warmup hit timeout/error — skipping; "
|
||||||
|
"first real request will pay the full compile cost."
|
||||||
|
)
|
||||||
|
|
||||||
log_system_snapshot("startup_complete")
|
log_system_snapshot("startup_complete")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
|
||||||
|
|
@ -47,11 +47,37 @@ def load_global_llm_configs():
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
configs = data.get("global_llm_configs", [])
|
configs = data.get("global_llm_configs", [])
|
||||||
|
|
||||||
|
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||||
|
# and matches the `provider_api_base` pattern used elsewhere.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
seen_slugs: dict[str, int] = {}
|
seen_slugs: dict[str, int] = {}
|
||||||
for cfg in configs:
|
for cfg in configs:
|
||||||
cfg.setdefault("billing_tier", "free")
|
cfg.setdefault("billing_tier", "free")
|
||||||
cfg.setdefault("anonymous_enabled", False)
|
cfg.setdefault("anonymous_enabled", False)
|
||||||
cfg.setdefault("seo_enabled", False)
|
cfg.setdefault("seo_enabled", False)
|
||||||
|
# Capability flag: explicit YAML override always wins. When the
|
||||||
|
# operator has not annotated the model, defer to LiteLLM's
|
||||||
|
# authoritative model map (`supports_vision`) which already
|
||||||
|
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
|
||||||
|
# vision-capable. Unknown / unmapped models default-allow so
|
||||||
|
# we don't lock the user out of a freshly added third-party
|
||||||
|
# entry; the streaming-task safety net (driven by
|
||||||
|
# `is_known_text_only_chat_model`) is the only place a False
|
||||||
|
# actually blocks a request.
|
||||||
|
if "supports_image_input" not in cfg:
|
||||||
|
litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
cfg["supports_image_input"] = derive_supports_image_input(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
model_name=cfg.get("model_name"),
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
|
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
|
||||||
slug = cfg["seo_slug"]
|
slug = cfg["seo_slug"]
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||||
|
from app.config import config
|
||||||
from app.db import User
|
from app.db import User
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
|
||||||
|
|
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
|
||||||
|
|
||||||
enable_otel: bool
|
enable_otel: bool
|
||||||
|
|
||||||
|
enable_desktop_local_filesystem: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
|
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
|
||||||
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
||||||
return cls(**asdict(flags))
|
return cls(
|
||||||
|
**asdict(flags),
|
||||||
|
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||||
|
|
|
||||||
|
|
@ -649,13 +649,9 @@ async def list_composio_drive_folders(
|
||||||
"""
|
"""
|
||||||
List folders AND files in user's Google Drive via Composio.
|
List folders AND files in user's Google Drive via Composio.
|
||||||
|
|
||||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
Uses Composio's Google Drive tool execution path so managed OAuth tokens
|
||||||
connector, with Composio-sourced credentials. This means auth errors
|
do not need to be exposed through connected account state.
|
||||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
|
||||||
"""
|
"""
|
||||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
|
||||||
|
|
||||||
if not ComposioService.is_enabled():
|
if not ComposioService.is_enabled():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503,
|
status_code=503,
|
||||||
|
|
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
|
||||||
detail="Composio connected account not found. Please reconnect the connector.",
|
detail="Composio connected account not found. Please reconnect the connector.",
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials = build_composio_credentials(composio_connected_account_id)
|
service = ComposioService()
|
||||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
entity_id = f"surfsense_{user.id}"
|
||||||
|
items = []
|
||||||
|
page_token = None
|
||||||
|
error = None
|
||||||
|
|
||||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
while True:
|
||||||
|
page_items, next_token, page_error = await service.get_drive_files(
|
||||||
|
connected_account_id=composio_connected_account_id,
|
||||||
|
entity_id=entity_id,
|
||||||
|
folder_id=parent_id,
|
||||||
|
page_token=page_token,
|
||||||
|
page_size=100,
|
||||||
|
)
|
||||||
|
if page_error:
|
||||||
|
error = page_error
|
||||||
|
break
|
||||||
|
|
||||||
|
items.extend(page_items)
|
||||||
|
if not next_token:
|
||||||
|
break
|
||||||
|
page_token = next_token
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
item["isFolder"] = (
|
||||||
|
item.get("mimeType") == "application/vnd.google-apps.folder"
|
||||||
|
)
|
||||||
|
|
||||||
|
items.sort(
|
||||||
|
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
|
||||||
|
)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
error_lower = error.lower()
|
error_lower = error.lower()
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ from app.services.image_gen_router_service import (
|
||||||
ImageGenRouterService,
|
ImageGenRouterService,
|
||||||
is_image_gen_auto_mode,
|
is_image_gen_auto_mode,
|
||||||
)
|
)
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
from app.utils.signed_image_urls import verify_image_token
|
from app.utils.signed_image_urls import verify_image_token
|
||||||
|
|
@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||||
|
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||||
|
if custom_provider:
|
||||||
|
return custom_provider
|
||||||
|
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||||
|
|
||||||
|
|
||||||
def _build_model_string(
|
def _build_model_string(
|
||||||
provider: str, model_name: str, custom_provider: str | None
|
provider: str, model_name: str, custom_provider: str | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a litellm model string from provider + model_name."""
|
"""Build a litellm model string from provider + model_name."""
|
||||||
if custom_provider:
|
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||||
return f"{custom_provider}/{model_name}"
|
|
||||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
|
||||||
return f"{prefix}/{model_name}"
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_billing_for_image_gen(
|
async def _resolve_billing_for_image_gen(
|
||||||
|
|
@ -187,12 +192,18 @@ async def _execute_image_generation(
|
||||||
if not cfg:
|
if not cfg:
|
||||||
raise ValueError(f"Global image generation config {config_id} not found")
|
raise ValueError(f"Global image generation config {config_id} not found")
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||||
if cfg.get("api_base"):
|
api_base = resolve_api_base(
|
||||||
gen_kwargs["api_base"] = cfg["api_base"]
|
provider=cfg.get("provider"),
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=cfg.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if cfg.get("api_version"):
|
if cfg.get("api_version"):
|
||||||
gen_kwargs["api_version"] = cfg["api_version"]
|
gen_kwargs["api_version"] = cfg["api_version"]
|
||||||
if cfg.get("litellm_params"):
|
if cfg.get("litellm_params"):
|
||||||
|
|
@ -214,12 +225,18 @@ async def _execute_image_generation(
|
||||||
if not db_cfg:
|
if not db_cfg:
|
||||||
raise ValueError(f"Image generation config {config_id} not found")
|
raise ValueError(f"Image generation config {config_id} not found")
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
db_cfg.provider.value, db_cfg.custom_provider
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||||
gen_kwargs["api_key"] = db_cfg.api_key
|
gen_kwargs["api_key"] = db_cfg.api_key
|
||||||
if db_cfg.api_base:
|
api_base = resolve_api_base(
|
||||||
gen_kwargs["api_base"] = db_cfg.api_base
|
provider=db_cfg.provider.value,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=db_cfg.api_base,
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if db_cfg.api_version:
|
if db_cfg.api_version:
|
||||||
gen_kwargs["api_version"] = db_cfg.api_version
|
gen_kwargs["api_version"] = db_cfg.api_version
|
||||||
if db_cfg.litellm_params:
|
if db_cfg.litellm_params:
|
||||||
|
|
@ -277,10 +294,12 @@ async def get_global_image_gen_configs(
|
||||||
# Auto mode currently treated as free until per-deployment
|
# Auto mode currently treated as free until per-deployment
|
||||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||||
"billing_tier": "free",
|
"billing_tier": "free",
|
||||||
|
"is_premium": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for cfg in global_configs:
|
for cfg in global_configs:
|
||||||
|
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||||
safe_configs.append(
|
safe_configs.append(
|
||||||
{
|
{
|
||||||
"id": cfg.get("id"),
|
"id": cfg.get("id"),
|
||||||
|
|
@ -293,7 +312,11 @@ async def get_global_image_gen_configs(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"billing_tier": cfg.get("billing_tier", "free"),
|
"billing_tier": billing_tier,
|
||||||
|
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||||
|
# selector's premium badge logic keys off the same
|
||||||
|
# field across chat / image / vision tabs.
|
||||||
|
"is_premium": billing_tier == "premium",
|
||||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from app.schemas import (
|
||||||
NewLLMConfigUpdate,
|
NewLLMConfigUpdate,
|
||||||
)
|
)
|
||||||
from app.services.llm_service import validate_llm_config
|
from app.services.llm_service import validate_llm_config
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
|
|
||||||
|
|
@ -36,6 +37,39 @@ router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
||||||
|
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
|
||||||
|
|
||||||
|
There is no DB column for ``supports_image_input`` — the value is
|
||||||
|
resolved at the API boundary from LiteLLM's authoritative model map
|
||||||
|
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
|
||||||
|
the response shape consistent across list / detail / create / update
|
||||||
|
endpoints without having to remember to set the field at every call
|
||||||
|
site.
|
||||||
|
"""
|
||||||
|
provider_value = (
|
||||||
|
config.provider.value
|
||||||
|
if hasattr(config.provider, "value")
|
||||||
|
else str(config.provider)
|
||||||
|
)
|
||||||
|
litellm_params = config.litellm_params or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||||
|
)
|
||||||
|
supports_image_input = derive_supports_image_input(
|
||||||
|
provider=provider_value,
|
||||||
|
model_name=config.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=config.custom_provider,
|
||||||
|
)
|
||||||
|
# ``model_validate`` runs the Pydantic conversion using the ORM
|
||||||
|
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
|
||||||
|
# then we layer the derived field on. ``model_copy(update=...)`` keeps
|
||||||
|
# the surface immutable from the caller's perspective.
|
||||||
|
base_read = NewLLMConfigRead.model_validate(config)
|
||||||
|
return base_read.model_copy(update={"supports_image_input": supports_image_input})
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Global Configs Routes
|
# Global Configs Routes
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
|
||||||
"seo_title": None,
|
"seo_title": None,
|
||||||
"seo_description": None,
|
"seo_description": None,
|
||||||
"quota_reserve_tokens": None,
|
"quota_reserve_tokens": None,
|
||||||
|
# Auto routes across the configured pool, which usually
|
||||||
|
# includes at least one vision-capable deployment, so
|
||||||
|
# treat Auto as image-capable. The router itself will
|
||||||
|
# still pick a vision-capable deployment for messages
|
||||||
|
# carrying image_url blocks (LiteLLM Router falls back
|
||||||
|
# on ``404`` per its ``allowed_fails`` policy).
|
||||||
|
"supports_image_input": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add individual global configs
|
# Add individual global configs
|
||||||
for cfg in global_configs:
|
for cfg in global_configs:
|
||||||
|
# Capability resolution: explicit value (YAML override or OR
|
||||||
|
# `_supports_image_input(model)` payload baked in by the
|
||||||
|
# OpenRouter integration service) wins. Fall back to the
|
||||||
|
# LiteLLM-driven helper which default-allows on unknown so
|
||||||
|
# we don't hide vision-capable models that happen to lack a
|
||||||
|
# YAML annotation. The streaming task safety net is the
|
||||||
|
# only place a False ever blocks.
|
||||||
|
if "supports_image_input" in cfg:
|
||||||
|
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||||
|
else:
|
||||||
|
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
cfg_base_model = (
|
||||||
|
cfg_litellm_params.get("base_model")
|
||||||
|
if isinstance(cfg_litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
supports_image_input = derive_supports_image_input(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
model_name=cfg.get("model_name"),
|
||||||
|
base_model=cfg_base_model,
|
||||||
|
custom_provider=cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
|
||||||
safe_config = {
|
safe_config = {
|
||||||
"id": cfg.get("id"),
|
"id": cfg.get("id"),
|
||||||
"name": cfg.get("name"),
|
"name": cfg.get("name"),
|
||||||
|
|
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
|
||||||
"seo_title": cfg.get("seo_title"),
|
"seo_title": cfg.get("seo_title"),
|
||||||
"seo_description": cfg.get("seo_description"),
|
"seo_description": cfg.get("seo_description"),
|
||||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||||
|
"supports_image_input": supports_image_input,
|
||||||
}
|
}
|
||||||
safe_configs.append(safe_config)
|
safe_configs.append(safe_config)
|
||||||
|
|
||||||
|
|
@ -171,7 +236,7 @@ async def create_new_llm_config(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(db_config)
|
await session.refresh(db_config)
|
||||||
|
|
||||||
return db_config
|
return _serialize_byok_config(db_config)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -213,7 +278,7 @@ async def list_new_llm_configs(
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.scalars().all()
|
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -268,7 +333,7 @@ async def get_new_llm_config(
|
||||||
"You don't have permission to view LLM configurations in this search space",
|
"You don't have permission to view LLM configurations in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return _serialize_byok_config(config)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -360,7 +425,7 @@ async def update_new_llm_config(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(config)
|
await session.refresh(config)
|
||||||
|
|
||||||
return config
|
return _serialize_byok_config(config)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -85,10 +85,12 @@ async def get_global_vision_llm_configs(
|
||||||
# Auto mode treated as free until per-deployment billing-tier
|
# Auto mode treated as free until per-deployment billing-tier
|
||||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||||
"billing_tier": "free",
|
"billing_tier": "free",
|
||||||
|
"is_premium": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for cfg in global_configs:
|
for cfg in global_configs:
|
||||||
|
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||||
safe_configs.append(
|
safe_configs.append(
|
||||||
{
|
{
|
||||||
"id": cfg.get("id"),
|
"id": cfg.get("id"),
|
||||||
|
|
@ -101,7 +103,11 @@ async def get_global_vision_llm_configs(
|
||||||
"api_version": cfg.get("api_version") or None,
|
"api_version": cfg.get("api_version") or None,
|
||||||
"litellm_params": cfg.get("litellm_params", {}),
|
"litellm_params": cfg.get("litellm_params", {}),
|
||||||
"is_global": True,
|
"is_global": True,
|
||||||
"billing_tier": cfg.get("billing_tier", "free"),
|
"billing_tier": billing_tier,
|
||||||
|
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||||
|
# selector's premium badge logic keys off the same
|
||||||
|
# field across chat / image / vision tabs.
|
||||||
|
"is_premium": billing_tier == "premium",
|
||||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||||
|
|
|
||||||
|
|
@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel):
|
||||||
default="free",
|
default="free",
|
||||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||||
)
|
)
|
||||||
|
is_premium: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Convenience boolean derived server-side from "
|
||||||
|
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||||
|
"keys its Free/Premium badge off this field for parity with "
|
||||||
|
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||||
|
),
|
||||||
|
)
|
||||||
quota_reserve_micros: int | None = Field(
|
quota_reserve_micros: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID
|
||||||
|
# Capability flag derived at the API boundary (no DB column). Default
|
||||||
|
# True matches the conservative-allow stance — a BYOK row that the
|
||||||
|
# route forgot to augment is not pre-judged. The streaming-task
|
||||||
|
# safety net is the only place a False actually blocks a request.
|
||||||
|
supports_image_input: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||||
|
"at the route boundary from LiteLLM's authoritative model map "
|
||||||
|
"(``litellm.supports_vision``) — there is no DB column. "
|
||||||
|
"Default True is the conservative-allow stance for unknown / "
|
||||||
|
"unmapped models."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID
|
||||||
|
# Capability flag derived at the API boundary (see NewLLMConfigRead).
|
||||||
|
supports_image_input: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||||
|
"at the route boundary from LiteLLM's authoritative model map. "
|
||||||
|
"Default True is the conservative-allow stance."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
|
||||||
seo_title: str | None = None
|
seo_title: str | None = None
|
||||||
seo_description: str | None = None
|
seo_description: str | None = None
|
||||||
quota_reserve_tokens: int | None = None
|
quota_reserve_tokens: int | None = None
|
||||||
|
supports_image_input: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Whether the model accepts image inputs (multimodal vision). "
|
||||||
|
"Derived server-side: OpenRouter dynamic configs use "
|
||||||
|
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
|
||||||
|
"authoritative model map (``litellm.supports_vision``). The "
|
||||||
|
"new-chat selector hints with a 'No image' badge when this is "
|
||||||
|
"False and there are pending image attachments. The streaming "
|
||||||
|
"task fails fast only when LiteLLM *explicitly* marks a model "
|
||||||
|
"as text-only — unknown / unmapped models default-allow."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel):
|
||||||
default="free",
|
default="free",
|
||||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||||
)
|
)
|
||||||
|
is_premium: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Convenience boolean derived server-side from "
|
||||||
|
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||||
|
"keys its Free/Premium badge off this field for parity with "
|
||||||
|
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||||
|
),
|
||||||
|
)
|
||||||
quota_reserve_tokens: int | None = Field(
|
quota_reserve_tokens: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
|
|
|
||||||
|
|
@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
|
||||||
_healthy_until.pop(int(config_id), None)
|
_healthy_until.pop(int(config_id), None)
|
||||||
|
|
||||||
|
|
||||||
def _global_candidates() -> list[dict]:
|
def _cfg_supports_image_input(cfg: dict) -> bool:
|
||||||
|
"""True if the global cfg can accept image inputs.
|
||||||
|
|
||||||
|
Prefers the explicit ``supports_image_input`` flag (set by the YAML
|
||||||
|
loader / OpenRouter integration). Falls back to a LiteLLM lookup so
|
||||||
|
a YAML entry whose flag was somehow stripped doesn't get wrongly
|
||||||
|
excluded. Default-allows on unknown — the streaming-task safety net
|
||||||
|
is the actual block, not this filter.
|
||||||
|
"""
|
||||||
|
if "supports_image_input" in cfg:
|
||||||
|
return bool(cfg.get("supports_image_input"))
|
||||||
|
# Lazy import: provider_capabilities -> llm_config -> services chain;
|
||||||
|
# importing at module load would create an init-order cycle through
|
||||||
|
# ``app.config``.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
|
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
cfg_litellm_params.get("base_model")
|
||||||
|
if isinstance(cfg_litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return derive_supports_image_input(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
model_name=cfg.get("model_name"),
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
||||||
"""Return Auto-eligible global cfgs.
|
"""Return Auto-eligible global cfgs.
|
||||||
|
|
||||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||||
can't be picked as the thread's pin. Also excludes configs currently
|
can't be picked as the thread's pin. Also excludes configs currently
|
||||||
in runtime cooldown (e.g. temporary 429 bursts).
|
in runtime cooldown (e.g. temporary 429 bursts).
|
||||||
|
|
||||||
|
When ``requires_image_input`` is True (image turn), additionally
|
||||||
|
filters out configs whose ``supports_image_input`` resolves to False
|
||||||
|
so a text-only deployment can't be pinned for an image request.
|
||||||
"""
|
"""
|
||||||
candidates = [
|
candidates = [
|
||||||
cfg
|
cfg
|
||||||
|
|
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
|
||||||
if _is_usable_global_config(cfg)
|
if _is_usable_global_config(cfg)
|
||||||
and not cfg.get("health_gated")
|
and not cfg.get("health_gated")
|
||||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||||
|
and (not requires_image_input or _cfg_supports_image_input(cfg))
|
||||||
]
|
]
|
||||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||||
|
|
||||||
|
|
@ -185,6 +220,15 @@ def _tier_of(cfg: dict) -> str:
|
||||||
return str(cfg.get("billing_tier", "free")).lower()
|
return str(cfg.get("billing_tier", "free")).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
|
||||||
|
"""Return True for the operator-preferred premium Auto model."""
|
||||||
|
return (
|
||||||
|
_tier_of(cfg) == "premium"
|
||||||
|
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
|
||||||
|
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||||
"""Pick a config with quality-first ranking + deterministic spread.
|
"""Pick a config with quality-first ranking + deterministic spread.
|
||||||
|
|
||||||
|
|
@ -237,11 +281,20 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
selected_llm_config_id: int,
|
selected_llm_config_id: int,
|
||||||
force_repin_free: bool = False,
|
force_repin_free: bool = False,
|
||||||
exclude_config_ids: set[int] | None = None,
|
exclude_config_ids: set[int] | None = None,
|
||||||
|
requires_image_input: bool = False,
|
||||||
) -> AutoPinResolution:
|
) -> AutoPinResolution:
|
||||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||||
|
|
||||||
For non-auto selections, this function clears any existing pin and returns
|
For non-auto selections, this function clears any existing pin and returns
|
||||||
the selected id as-is.
|
the selected id as-is.
|
||||||
|
|
||||||
|
When ``requires_image_input`` is True (the current turn carries an
|
||||||
|
``image_url`` block), the candidate pool is filtered to vision-capable
|
||||||
|
cfgs and any existing pin that can't accept image input is treated as
|
||||||
|
invalid (force re-pin). If no vision-capable cfg is available the
|
||||||
|
function raises ``ValueError`` so the streaming task surfaces the same
|
||||||
|
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
|
||||||
|
silently routing the image to a text-only deployment.
|
||||||
"""
|
"""
|
||||||
thread = (
|
thread = (
|
||||||
(
|
(
|
||||||
|
|
@ -274,14 +327,24 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
|
|
||||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||||
candidates = [
|
candidates = [
|
||||||
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
c
|
||||||
|
for c in _global_candidates(requires_image_input=requires_image_input)
|
||||||
|
if int(c.get("id", 0)) not in excluded_ids
|
||||||
]
|
]
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
if requires_image_input:
|
||||||
|
# Distinguish the "no vision-capable cfg" case from generic
|
||||||
|
# "no usable cfg" so the streaming task can map this to the
|
||||||
|
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||||
|
raise ValueError(
|
||||||
|
"No vision-capable global LLM configs are available for Auto mode"
|
||||||
|
)
|
||||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||||
|
|
||||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||||
# tier switch), unless the caller explicitly requests a forced repin to free.
|
# tier switch), unless the caller explicitly requests a forced repin to free
|
||||||
|
# *or* the turn requires image input but the pin can't handle it.
|
||||||
pinned_id = thread.pinned_llm_config_id
|
pinned_id = thread.pinned_llm_config_id
|
||||||
if (
|
if (
|
||||||
not force_repin_free
|
not force_repin_free
|
||||||
|
|
@ -311,6 +374,29 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
from_existing_pin=True,
|
from_existing_pin=True,
|
||||||
)
|
)
|
||||||
if pinned_id is not None:
|
if pinned_id is not None:
|
||||||
|
# If the pin is *only* invalid because it can't handle the image
|
||||||
|
# turn (it's still a healthy, usable config in the broader pool),
|
||||||
|
# log that explicitly so operators can correlate the re-pin with
|
||||||
|
# the user's image attachment instead of suspecting a cooldown.
|
||||||
|
if requires_image_input:
|
||||||
|
try:
|
||||||
|
pinned_global = next(
|
||||||
|
c
|
||||||
|
for c in config.GLOBAL_LLM_CONFIGS
|
||||||
|
if int(c.get("id", 0)) == int(pinned_id)
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
pinned_global = None
|
||||||
|
if pinned_global is not None and not _cfg_supports_image_input(
|
||||||
|
pinned_global
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||||
|
"previous_config_id=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
pinned_id,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||||
thread_id,
|
thread_id,
|
||||||
|
|
@ -322,11 +408,19 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
False if force_repin_free else await _is_premium_eligible(session, user_id)
|
False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||||
)
|
)
|
||||||
if premium_eligible:
|
if premium_eligible:
|
||||||
eligible = candidates
|
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
|
||||||
|
preferred_premium = [
|
||||||
|
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
|
||||||
|
]
|
||||||
|
eligible = preferred_premium or premium_candidates
|
||||||
else:
|
else:
|
||||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||||
|
|
||||||
if not eligible:
|
if not eligible:
|
||||||
|
if requires_image_input:
|
||||||
|
raise ValueError(
|
||||||
|
"Auto mode could not find a vision-capable LLM config for this user and quota state"
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it.
|
||||||
|
|
||||||
KEY DESIGN POINTS (issue A, B):
|
KEY DESIGN POINTS (issue A, B):
|
||||||
|
|
||||||
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
|
1. **Session isolation.** ``billable_call`` takes no caller transaction.
|
||||||
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
|
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
|
||||||
insert each run inside their own ``shielded_async_session()``. This
|
inside their own session context. Route callers use
|
||||||
guarantees that a quota commit/rollback can never accidentally flush or
|
``shielded_async_session()`` by default; Celery callers can provide a
|
||||||
roll back rows the caller has staged in the request's main session
|
worker-loop-safe session factory. This guarantees that quota
|
||||||
(e.g. a freshly-created ``ImageGeneration`` row).
|
commit/rollback can never accidentally flush or roll back rows the caller
|
||||||
|
has staged in its main session (e.g. a freshly-created
|
||||||
|
``ImageGeneration`` row).
|
||||||
|
|
||||||
2. **ContextVar safety.** The accumulator is scoped via
|
2. **ContextVar safety.** The accumulator is scoped via
|
||||||
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||||
|
|
@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B):
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator, Callable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
|
@ -58,6 +61,12 @@ from app.services.token_tracking_service import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AUDIT_TIMEOUT_SECONDS = 10.0
|
||||||
|
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
|
||||||
|
{"video_presentation_generation", "podcast_generation"}
|
||||||
|
)
|
||||||
|
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
|
||||||
|
|
||||||
|
|
||||||
class QuotaInsufficientError(Exception):
|
class QuotaInsufficientError(Exception):
|
||||||
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
|
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
|
||||||
|
|
@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillingSettlementError(Exception):
|
||||||
|
"""Raised when a premium call completed but credit settlement failed."""
|
||||||
|
|
||||||
|
def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
|
||||||
|
self.usage_type = usage_type
|
||||||
|
self.user_id = user_id
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _rollback_safely(session: AsyncSession) -> None:
|
||||||
|
rollback = getattr(session, "rollback", None)
|
||||||
|
if rollback is not None:
|
||||||
|
with suppress(Exception):
|
||||||
|
await rollback()
|
||||||
|
|
||||||
|
|
||||||
|
async def _record_audit_best_effort(
|
||||||
|
*,
|
||||||
|
session_factory: BillableSessionFactory,
|
||||||
|
usage_type: str,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: UUID,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
total_tokens: int,
|
||||||
|
cost_micros: int,
|
||||||
|
model_breakdown: dict[str, Any],
|
||||||
|
call_details: dict[str, Any] | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
message_id: int | None,
|
||||||
|
audit_label: str,
|
||||||
|
timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
"""Persist a TokenUsage row without letting audit failure block callers.
|
||||||
|
|
||||||
|
Premium settlement is mandatory, but TokenUsage is an audit trail. If the
|
||||||
|
audit insert or commit hangs, user-facing artifacts such as videos and
|
||||||
|
podcasts must still be able to transition to READY after settlement.
|
||||||
|
"""
|
||||||
|
audit_thread_id = (
|
||||||
|
None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _persist() -> None:
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
|
||||||
|
"total_tokens=%d cost_micros=%d",
|
||||||
|
audit_label,
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
audit_thread_id,
|
||||||
|
total_tokens,
|
||||||
|
cost_micros,
|
||||||
|
)
|
||||||
|
async with session_factory() as audit_session:
|
||||||
|
try:
|
||||||
|
await record_token_usage(
|
||||||
|
audit_session,
|
||||||
|
usage_type=usage_type,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cost_micros=cost_micros,
|
||||||
|
model_breakdown=model_breakdown,
|
||||||
|
call_details=call_details,
|
||||||
|
thread_id=audit_thread_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
|
||||||
|
audit_label,
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
audit_thread_id,
|
||||||
|
)
|
||||||
|
await audit_session.commit()
|
||||||
|
logger.info(
|
||||||
|
"[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
|
||||||
|
audit_label,
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
audit_thread_id,
|
||||||
|
)
|
||||||
|
except BaseException:
|
||||||
|
await _rollback_safely(audit_session)
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_persist(), timeout=timeout_seconds)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
|
||||||
|
"timeout=%.1fs total_tokens=%d cost_micros=%d",
|
||||||
|
audit_label,
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
audit_thread_id,
|
||||||
|
timeout_seconds,
|
||||||
|
total_tokens,
|
||||||
|
cost_micros,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
|
||||||
|
"total_tokens=%d cost_micros=%d",
|
||||||
|
audit_label,
|
||||||
|
usage_type,
|
||||||
|
user_id,
|
||||||
|
audit_thread_id,
|
||||||
|
total_tokens,
|
||||||
|
cost_micros,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def billable_call(
|
async def billable_call(
|
||||||
*,
|
*,
|
||||||
|
|
@ -101,6 +228,8 @@ async def billable_call(
|
||||||
thread_id: int | None = None,
|
thread_id: int | None = None,
|
||||||
message_id: int | None = None,
|
message_id: int | None = None,
|
||||||
call_details: dict[str, Any] | None = None,
|
call_details: dict[str, Any] | None = None,
|
||||||
|
billable_session_factory: BillableSessionFactory | None = None,
|
||||||
|
audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
|
||||||
) -> AsyncIterator[TurnTokenAccumulator]:
|
) -> AsyncIterator[TurnTokenAccumulator]:
|
||||||
"""Wrap a single billable LLM/image call.
|
"""Wrap a single billable LLM/image call.
|
||||||
|
|
||||||
|
|
@ -124,6 +253,13 @@ async def billable_call(
|
||||||
thread_id, message_id: Optional FK columns on ``TokenUsage``.
|
thread_id, message_id: Optional FK columns on ``TokenUsage``.
|
||||||
call_details: Optional per-call metadata (model name, parameters)
|
call_details: Optional per-call metadata (model name, parameters)
|
||||||
forwarded to ``record_token_usage``.
|
forwarded to ``record_token_usage``.
|
||||||
|
billable_session_factory: Optional async context factory used for
|
||||||
|
reserve/finalize/release/audit sessions. Defaults to
|
||||||
|
``shielded_async_session`` for route callers; Celery callers pass
|
||||||
|
a worker-loop-safe session factory.
|
||||||
|
audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
|
||||||
|
Audit failure is best-effort and does not undo successful
|
||||||
|
settlement.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
|
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
|
||||||
|
|
@ -134,6 +270,7 @@ async def billable_call(
|
||||||
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
|
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
|
||||||
"""
|
"""
|
||||||
is_premium = billing_tier == "premium"
|
is_premium = billing_tier == "premium"
|
||||||
|
session_factory = billable_session_factory or shielded_async_session
|
||||||
|
|
||||||
async with scoped_turn() as acc:
|
async with scoped_turn() as acc:
|
||||||
# ---------- Free path: just audit -------------------------------
|
# ---------- Free path: just audit -------------------------------
|
||||||
|
|
@ -143,10 +280,8 @@ async def billable_call(
|
||||||
finally:
|
finally:
|
||||||
# Always audit, even on exception, so we capture cost when
|
# Always audit, even on exception, so we capture cost when
|
||||||
# provider returns successfully but the caller raises later.
|
# provider returns successfully but the caller raises later.
|
||||||
try:
|
await _record_audit_best_effort(
|
||||||
async with shielded_async_session() as audit_session:
|
session_factory=session_factory,
|
||||||
await record_token_usage(
|
|
||||||
audit_session,
|
|
||||||
usage_type=usage_type,
|
usage_type=usage_type,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -158,14 +293,8 @@ async def billable_call(
|
||||||
call_details=call_details,
|
call_details=call_details,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
audit_label="free",
|
||||||
await audit_session.commit()
|
timeout_seconds=audit_timeout_seconds,
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"[billable_call] free-path audit insert failed for "
|
|
||||||
"usage_type=%s user_id=%s",
|
|
||||||
usage_type,
|
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -180,7 +309,7 @@ async def billable_call(
|
||||||
|
|
||||||
request_id = str(uuid4())
|
request_id = str(uuid4())
|
||||||
|
|
||||||
async with shielded_async_session() as quota_session:
|
async with session_factory() as quota_session:
|
||||||
reserve_result = await TokenQuotaService.premium_reserve(
|
reserve_result = await TokenQuotaService.premium_reserve(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -222,7 +351,7 @@ async def billable_call(
|
||||||
# from a downstream call, asyncio cancellation, etc.). We use
|
# from a downstream call, asyncio cancellation, etc.). We use
|
||||||
# BaseException so cancellation also releases.
|
# BaseException so cancellation also releases.
|
||||||
try:
|
try:
|
||||||
async with shielded_async_session() as quota_session:
|
async with session_factory() as quota_session:
|
||||||
await TokenQuotaService.premium_release(
|
await TokenQuotaService.premium_release(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -241,7 +370,16 @@ async def billable_call(
|
||||||
# ---------- Success: finalize + audit ----------------------------
|
# ---------- Success: finalize + audit ----------------------------
|
||||||
actual_micros = acc.total_cost_micros
|
actual_micros = acc.total_cost_micros
|
||||||
try:
|
try:
|
||||||
async with shielded_async_session() as quota_session:
|
logger.info(
|
||||||
|
"[billable_call] finalize start user=%s usage_type=%s actual=%d "
|
||||||
|
"reserved=%d thread=%s",
|
||||||
|
user_id,
|
||||||
|
usage_type,
|
||||||
|
actual_micros,
|
||||||
|
reserve_micros,
|
||||||
|
thread_id,
|
||||||
|
)
|
||||||
|
async with session_factory() as quota_session:
|
||||||
final_result = await TokenQuotaService.premium_finalize(
|
final_result = await TokenQuotaService.premium_finalize(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -260,7 +398,7 @@ async def billable_call(
|
||||||
final_result.limit,
|
final_result.limit,
|
||||||
final_result.remaining,
|
final_result.remaining,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as finalize_exc:
|
||||||
# Last-ditch: if finalize itself fails, we must at least release
|
# Last-ditch: if finalize itself fails, we must at least release
|
||||||
# so the reservation doesn't leak.
|
# so the reservation doesn't leak.
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
|
@ -269,7 +407,7 @@ async def billable_call(
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async with shielded_async_session() as quota_session:
|
async with session_factory() as quota_session:
|
||||||
await TokenQuotaService.premium_release(
|
await TokenQuotaService.premium_release(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -281,11 +419,14 @@ async def billable_call(
|
||||||
"for user=%s",
|
"for user=%s",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
raise BillingSettlementError(
|
||||||
|
usage_type=usage_type,
|
||||||
|
user_id=user_id,
|
||||||
|
cause=finalize_exc,
|
||||||
|
) from finalize_exc
|
||||||
|
|
||||||
try:
|
await _record_audit_best_effort(
|
||||||
async with shielded_async_session() as audit_session:
|
session_factory=session_factory,
|
||||||
await record_token_usage(
|
|
||||||
audit_session,
|
|
||||||
usage_type=usage_type,
|
usage_type=usage_type,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -297,14 +438,8 @@ async def billable_call(
|
||||||
call_details=call_details,
|
call_details=call_details,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
audit_label="premium",
|
||||||
await audit_session.commit()
|
timeout_seconds=audit_timeout_seconds,
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"[billable_call] premium-path audit insert failed for "
|
|
||||||
"usage_type=%s user_id=%s (debit was applied)",
|
|
||||||
usage_type,
|
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space(
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BillingSettlementError",
|
||||||
"QuotaInsufficientError",
|
"QuotaInsufficientError",
|
||||||
"_resolve_agent_billing_for_search_space",
|
"_resolve_agent_billing_for_search_space",
|
||||||
"billable_call",
|
"billable_call",
|
||||||
|
|
|
||||||
|
|
@ -408,12 +408,37 @@ class ComposioService:
|
||||||
files = []
|
files = []
|
||||||
next_token = None
|
next_token = None
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
# Try direct access first, then nested
|
# Try direct access first, then nested
|
||||||
files = data.get("files", []) or data.get("data", {}).get("files", [])
|
files = (
|
||||||
|
data.get("files", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("files", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("files", [])
|
||||||
|
)
|
||||||
next_token = (
|
next_token = (
|
||||||
data.get("nextPageToken")
|
data.get("nextPageToken")
|
||||||
or data.get("next_page_token")
|
or data.get("next_page_token")
|
||||||
or data.get("data", {}).get("nextPageToken")
|
or (
|
||||||
|
inner_data.get("nextPageToken")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("next_page_token")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("nextPageToken")
|
||||||
|
or response_data.get("next_page_token")
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
files = data
|
files = data
|
||||||
|
|
@ -819,24 +844,61 @@ class ComposioService:
|
||||||
next_token = None
|
next_token = None
|
||||||
result_size_estimate = None
|
result_size_estimate = None
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
messages = (
|
messages = (
|
||||||
data.get("messages", [])
|
data.get("messages", [])
|
||||||
or data.get("data", {}).get("messages", [])
|
or (
|
||||||
|
inner_data.get("messages", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("messages", [])
|
||||||
or data.get("emails", [])
|
or data.get("emails", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("emails", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("emails", [])
|
||||||
)
|
)
|
||||||
# Check for pagination token in various possible locations
|
# Check for pagination token in various possible locations
|
||||||
next_token = (
|
next_token = (
|
||||||
data.get("nextPageToken")
|
data.get("nextPageToken")
|
||||||
or data.get("next_page_token")
|
or data.get("next_page_token")
|
||||||
or data.get("data", {}).get("nextPageToken")
|
or (
|
||||||
or data.get("data", {}).get("next_page_token")
|
inner_data.get("nextPageToken")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("next_page_token")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("nextPageToken")
|
||||||
|
or response_data.get("next_page_token")
|
||||||
)
|
)
|
||||||
# Extract resultSizeEstimate if available (Gmail API provides this)
|
# Extract resultSizeEstimate if available (Gmail API provides this)
|
||||||
result_size_estimate = (
|
result_size_estimate = (
|
||||||
data.get("resultSizeEstimate")
|
data.get("resultSizeEstimate")
|
||||||
or data.get("result_size_estimate")
|
or data.get("result_size_estimate")
|
||||||
or data.get("data", {}).get("resultSizeEstimate")
|
or (
|
||||||
or data.get("data", {}).get("result_size_estimate")
|
inner_data.get("resultSizeEstimate")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
inner_data.get("result_size_estimate")
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
or response_data.get("resultSizeEstimate")
|
||||||
|
or response_data.get("result_size_estimate")
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
messages = data
|
messages = data
|
||||||
|
|
@ -864,7 +926,7 @@ class ComposioService:
|
||||||
try:
|
try:
|
||||||
result = await self.execute_tool(
|
result = await self.execute_tool(
|
||||||
connected_account_id=connected_account_id,
|
connected_account_id=connected_account_id,
|
||||||
tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
|
tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
|
||||||
params={"message_id": message_id}, # snake_case
|
params={"message_id": message_id}, # snake_case
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
)
|
)
|
||||||
|
|
@ -872,7 +934,13 @@ class ComposioService:
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
return None, result.get("error", "Unknown error")
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
return result.get("data"), None
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
return inner_data.get("response_data", inner_data), None
|
||||||
|
|
||||||
|
return data, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get Gmail message detail: {e!s}")
|
logger.error(f"Failed to get Gmail message detail: {e!s}")
|
||||||
|
|
@ -928,10 +996,27 @@ class ComposioService:
|
||||||
# Try different possible response structures
|
# Try different possible response structures
|
||||||
events = []
|
events = []
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
response_data = (
|
||||||
|
inner_data.get("response_data", {})
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
events = (
|
events = (
|
||||||
data.get("items", [])
|
data.get("items", [])
|
||||||
or data.get("data", {}).get("items", [])
|
or (
|
||||||
|
inner_data.get("items", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("items", [])
|
||||||
or data.get("events", [])
|
or data.get("events", [])
|
||||||
|
or (
|
||||||
|
inner_data.get("events", [])
|
||||||
|
if isinstance(inner_data, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
or response_data.get("events", [])
|
||||||
)
|
)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
events = data
|
events = data
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from threading import Lock
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -2769,12 +2771,22 @@ class ConnectorService:
|
||||||
"""
|
"""
|
||||||
Get all available (enabled) connector types for a search space.
|
Get all available (enabled) connector types for a search space.
|
||||||
|
|
||||||
|
Phase 1.4: results are cached per ``search_space_id`` for
|
||||||
|
:data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
|
||||||
|
identity — the cached value is plain data, safe to share across
|
||||||
|
requests. Invalidate on connector add/update/delete via
|
||||||
|
:func:`invalidate_connector_discovery_cache`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The search space ID
|
search_space_id: The search space ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of SearchSourceConnectorType enums for enabled connectors
|
List of SearchSourceConnectorType enums for enabled connectors
|
||||||
"""
|
"""
|
||||||
|
cached = _get_cached_connectors(search_space_id)
|
||||||
|
if cached is not None:
|
||||||
|
return list(cached)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
select(SearchSourceConnector.connector_type)
|
select(SearchSourceConnector.connector_type)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -2784,8 +2796,9 @@ class ConnectorService:
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.session.execute(query)
|
result = await self.session.execute(query)
|
||||||
connector_types = result.scalars().all()
|
connector_types = list(result.scalars().all())
|
||||||
return list(connector_types)
|
_set_cached_connectors(search_space_id, connector_types)
|
||||||
|
return connector_types
|
||||||
|
|
||||||
async def get_available_document_types(
|
async def get_available_document_types(
|
||||||
self,
|
self,
|
||||||
|
|
@ -2794,12 +2807,22 @@ class ConnectorService:
|
||||||
"""
|
"""
|
||||||
Get all document types that have at least one document in the search space.
|
Get all document types that have at least one document in the search space.
|
||||||
|
|
||||||
|
Phase 1.4: cached per ``search_space_id`` for
|
||||||
|
:data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
|
||||||
|
:func:`invalidate_connector_discovery_cache` when a connector
|
||||||
|
finishes indexing new documents (or document types are otherwise
|
||||||
|
added/removed).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The search space ID
|
search_space_id: The search space ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of document type strings that have documents indexed
|
List of document type strings that have documents indexed
|
||||||
"""
|
"""
|
||||||
|
cached = _get_cached_doc_types(search_space_id)
|
||||||
|
if cached is not None:
|
||||||
|
return list(cached)
|
||||||
|
|
||||||
from sqlalchemy import distinct
|
from sqlalchemy import distinct
|
||||||
|
|
||||||
from app.db import Document
|
from app.db import Document
|
||||||
|
|
@ -2809,5 +2832,164 @@ class ConnectorService:
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.session.execute(query)
|
result = await self.session.execute(query)
|
||||||
doc_types = result.scalars().all()
|
doc_types = [str(dt) for dt in result.scalars().all()]
|
||||||
return [str(dt) for dt in doc_types]
|
_set_cached_doc_types(search_space_id, doc_types)
|
||||||
|
return doc_types
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Connector / document-type discovery TTL cache (Phase 1.4)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Both ``get_available_connectors`` and ``get_available_document_types`` are
|
||||||
|
# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
|
||||||
|
# hits Postgres and contributes to per-turn agent build latency. Their
|
||||||
|
# results change infrequently — only when the user adds/edits/removes a
|
||||||
|
# connector, or when an indexer commits a new document type. A short TTL
|
||||||
|
# cache (default 30s, env-tunable) collapses N concurrent calls into one
|
||||||
|
# DB roundtrip with bounded staleness.
|
||||||
|
#
|
||||||
|
# Invalidation: connector mutation routes (create / update / delete) call
|
||||||
|
# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
|
||||||
|
# entry for the affected space. Multi-replica deployments still pay one
|
||||||
|
# DB roundtrip per replica per TTL window, which is fine — staleness is
|
||||||
|
# bounded and the alternative (cross-replica fanout) is not worth the
|
||||||
|
# coupling here.
|
||||||
|
|
||||||
|
_DISCOVERY_TTL_SECONDS: float = float(
|
||||||
|
os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-search-space caches. Keyed by ``search_space_id``; value is
|
||||||
|
# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
|
||||||
|
# read-mostly workload, sub-microsecond contention.
|
||||||
|
_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
|
||||||
|
_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
|
||||||
|
_cache_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cached_connectors(
|
||||||
|
search_space_id: int,
|
||||||
|
) -> list[SearchSourceConnectorType] | None:
|
||||||
|
if _DISCOVERY_TTL_SECONDS <= 0:
|
||||||
|
return None
|
||||||
|
with _cache_lock:
|
||||||
|
entry = _connectors_cache.get(search_space_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
expires_at, payload = entry
|
||||||
|
if time.monotonic() >= expires_at:
|
||||||
|
_connectors_cache.pop(search_space_id, None)
|
||||||
|
return None
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _set_cached_connectors(
|
||||||
|
search_space_id: int, payload: list[SearchSourceConnectorType]
|
||||||
|
) -> None:
|
||||||
|
if _DISCOVERY_TTL_SECONDS <= 0:
|
||||||
|
return
|
||||||
|
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
|
||||||
|
with _cache_lock:
|
||||||
|
_connectors_cache[search_space_id] = (expires_at, list(payload))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
|
||||||
|
if _DISCOVERY_TTL_SECONDS <= 0:
|
||||||
|
return None
|
||||||
|
with _cache_lock:
|
||||||
|
entry = _doc_types_cache.get(search_space_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
expires_at, payload = entry
|
||||||
|
if time.monotonic() >= expires_at:
|
||||||
|
_doc_types_cache.pop(search_space_id, None)
|
||||||
|
return None
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
|
||||||
|
if _DISCOVERY_TTL_SECONDS <= 0:
|
||||||
|
return
|
||||||
|
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
|
||||||
|
with _cache_lock:
|
||||||
|
_doc_types_cache[search_space_id] = (expires_at, list(payload))
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
|
||||||
|
"""Drop cached discovery results for ``search_space_id`` (or all spaces).
|
||||||
|
|
||||||
|
Connector CRUD routes / indexer pipelines call this when they mutate
|
||||||
|
the rows backing :func:`ConnectorService.get_available_connectors` /
|
||||||
|
:func:`get_available_document_types`. ``None`` clears every space —
|
||||||
|
useful in tests and on bulk imports.
|
||||||
|
"""
|
||||||
|
with _cache_lock:
|
||||||
|
if search_space_id is None:
|
||||||
|
_connectors_cache.clear()
|
||||||
|
_doc_types_cache.clear()
|
||||||
|
else:
|
||||||
|
_connectors_cache.pop(search_space_id, None)
|
||||||
|
_doc_types_cache.pop(search_space_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
|
||||||
|
with _cache_lock:
|
||||||
|
if search_space_id is None:
|
||||||
|
_connectors_cache.clear()
|
||||||
|
else:
|
||||||
|
_connectors_cache.pop(search_space_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
|
||||||
|
with _cache_lock:
|
||||||
|
if search_space_id is None:
|
||||||
|
_doc_types_cache.clear()
|
||||||
|
else:
|
||||||
|
_doc_types_cache.pop(search_space_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_invalidation_listeners() -> None:
|
||||||
|
"""Wire SQLAlchemy ORM events so cache stays consistent automatically.
|
||||||
|
|
||||||
|
Listening on ``after_insert`` / ``after_update`` / ``after_delete``
|
||||||
|
means every successful INSERT/UPDATE/DELETE that goes through the ORM
|
||||||
|
invalidates the affected search space's cached discovery payload —
|
||||||
|
no need to sprinkle ``invalidate_*`` calls across 30+ connector
|
||||||
|
routes. Bulk operations that bypass the ORM (e.g.
|
||||||
|
``session.execute(insert(...))`` without a mapped object) still need
|
||||||
|
explicit invalidation; document indexers already commit through the
|
||||||
|
ORM so document-type discovery is covered.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import event
|
||||||
|
|
||||||
|
# Imported here (not at module top) to avoid a circular import:
|
||||||
|
# app.services.connector_service is itself imported from app.db's
|
||||||
|
# ecosystem indirectly via several CRUD modules.
|
||||||
|
from app.db import Document, SearchSourceConnector
|
||||||
|
|
||||||
|
def _connector_changed(_mapper, _connection, target) -> None:
|
||||||
|
sid = getattr(target, "search_space_id", None)
|
||||||
|
if sid is not None:
|
||||||
|
_invalidate_connectors_only(int(sid))
|
||||||
|
|
||||||
|
def _document_changed(_mapper, _connection, target) -> None:
|
||||||
|
sid = getattr(target, "search_space_id", None)
|
||||||
|
if sid is not None:
|
||||||
|
_invalidate_doc_types_only(int(sid))
|
||||||
|
|
||||||
|
for evt in ("after_insert", "after_update", "after_delete"):
|
||||||
|
event.listen(SearchSourceConnector, evt, _connector_changed)
|
||||||
|
event.listen(Document, evt, _document_changed)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
_register_invalidation_listeners()
|
||||||
|
except Exception: # pragma: no cover - defensive; never block module import
|
||||||
|
import logging as _logging
|
||||||
|
|
||||||
|
_logging.getLogger(__name__).exception(
|
||||||
|
"Failed to register connector discovery cache invalidation listeners; "
|
||||||
|
"stale cache risk: explicit invalidate_connector_discovery_cache calls "
|
||||||
|
"may be required."
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -78,14 +78,49 @@ class GmailToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
if (
|
return (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
):
|
)
|
||||||
|
|
||||||
|
def _get_composio_connected_account_id(
|
||||||
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
return build_composio_credentials(cca_id)
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
def _unwrap_composio_data(self, data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner)
|
||||||
|
return inner
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _execute_composio_gmail_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> tuple[Any, str | None]:
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Gmail error")
|
||||||
|
return self._unwrap_composio_data(result.get("data")), None
|
||||||
|
|
||||||
|
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
raise ValueError(
|
||||||
|
"Composio Gmail connectors must use Composio tool execution"
|
||||||
|
)
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
@ -139,6 +174,12 @@ class GmailToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
_profile, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
service = build("gmail", "v1", credentials=creds)
|
service = build("gmail", "v1", credentials=creds)
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
|
@ -221,6 +262,13 @@ class GmailToolMetadataService:
|
||||||
)
|
)
|
||||||
connector = result.scalar_one_or_none()
|
connector = result.scalar_one_or_none()
|
||||||
if connector:
|
if connector:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
profile, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
service = build("gmail", "v1", credentials=creds)
|
service = build("gmail", "v1", credentials=creds)
|
||||||
profile = await asyncio.get_event_loop().run_in_executor(
|
profile = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
|
@ -298,6 +346,23 @@ class GmailToolMetadataService:
|
||||||
Returns ``None`` on any failure so callers can degrade gracefully.
|
Returns ``None`` on any failure so callers can degrade gracefully.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
if not draft_id:
|
||||||
|
draft_id = await self._find_composio_draft_id(connector, message_id)
|
||||||
|
if not draft_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
draft, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
"GMAIL_GET_DRAFT",
|
||||||
|
{"user_id": "me", "draft_id": draft_id, "format": "full"},
|
||||||
|
)
|
||||||
|
if error or not isinstance(draft, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = draft.get("message", {}).get("payload", {})
|
||||||
|
return self._extract_body_from_payload(payload)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
service = build("gmail", "v1", credentials=creds)
|
service = build("gmail", "v1", credentials=creds)
|
||||||
|
|
||||||
|
|
@ -326,6 +391,33 @@ class GmailToolMetadataService:
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _find_composio_draft_id(
|
||||||
|
self, connector: SearchSourceConnector, message_id: str
|
||||||
|
) -> str | None:
|
||||||
|
page_token = ""
|
||||||
|
while True:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"user_id": "me",
|
||||||
|
"max_results": 100,
|
||||||
|
"verbose": False,
|
||||||
|
}
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
data, error = await self._execute_composio_gmail_tool(
|
||||||
|
connector, "GMAIL_LIST_DRAFTS", params
|
||||||
|
)
|
||||||
|
if error or not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for draft in data.get("drafts", []):
|
||||||
|
if draft.get("message", {}).get("id") == message_id:
|
||||||
|
return draft.get("id")
|
||||||
|
|
||||||
|
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
|
||||||
|
if not page_token:
|
||||||
|
return None
|
||||||
|
|
||||||
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
||||||
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.utils.document_converters import (
|
from app.utils.document_converters import (
|
||||||
create_document_chunks,
|
create_document_chunks,
|
||||||
embed_text,
|
embed_text,
|
||||||
|
|
@ -21,7 +22,6 @@ from app.utils.document_converters import (
|
||||||
generate_document_summary,
|
generate_document_summary,
|
||||||
generate_unique_identifier_hash,
|
generate_unique_identifier_hash,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -203,15 +203,38 @@ class GoogleCalendarKBSyncService:
|
||||||
logger.warning("Document %s not found in KB", document_id)
|
logger.warning("Document %s not found in KB", document_id)
|
||||||
return {"status": "not_indexed"}
|
return {"status": "not_indexed"}
|
||||||
|
|
||||||
|
calendar_id = (document.document_metadata or {}).get(
|
||||||
|
"calendar_id"
|
||||||
|
) or "primary"
|
||||||
|
connector = await self._get_connector(connector_id)
|
||||||
|
if (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
|
):
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
composio_result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=cca_id,
|
||||||
|
tool_name="GOOGLECALENDAR_EVENTS_GET",
|
||||||
|
params={"calendar_id": calendar_id, "event_id": event_id},
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
)
|
||||||
|
if not composio_result.get("success"):
|
||||||
|
raise RuntimeError(
|
||||||
|
composio_result.get("error", "Unknown Composio Calendar error")
|
||||||
|
)
|
||||||
|
live_event = composio_result.get("data", {})
|
||||||
|
if isinstance(live_event, dict):
|
||||||
|
live_event = live_event.get("data", live_event)
|
||||||
|
if isinstance(live_event, dict):
|
||||||
|
live_event = live_event.get("response_data", live_event)
|
||||||
|
else:
|
||||||
creds = await self._build_credentials_for_connector(connector_id)
|
creds = await self._build_credentials_for_connector(connector_id)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
service = await loop.run_in_executor(
|
service = await loop.run_in_executor(
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
)
|
)
|
||||||
|
|
||||||
calendar_id = (document.document_metadata or {}).get(
|
|
||||||
"calendar_id"
|
|
||||||
) or "primary"
|
|
||||||
live_event = await loop.run_in_executor(
|
live_event = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
|
||||||
await self.db_session.rollback()
|
await self.db_session.rollback()
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
|
||||||
result = await self.db_session.execute(
|
result = await self.db_session.execute(
|
||||||
select(SearchSourceConnector).where(
|
select(SearchSourceConnector).where(
|
||||||
SearchSourceConnector.id == connector_id
|
SearchSourceConnector.id == connector_id
|
||||||
|
|
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
|
||||||
connector = result.scalar_one_or_none()
|
connector = result.scalar_one_or_none()
|
||||||
if not connector:
|
if not connector:
|
||||||
raise ValueError(f"Connector {connector_id} not found")
|
raise ValueError(f"Connector {connector_id} not found")
|
||||||
|
return connector
|
||||||
|
|
||||||
|
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||||
|
connector = await self._get_connector(connector_id)
|
||||||
if (
|
if (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
):
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
raise ValueError(
|
||||||
if cca_id:
|
"Composio Calendar connectors must use Composio tool execution"
|
||||||
return build_composio_credentials(cca_id)
|
)
|
||||||
raise ValueError("Composio connected_account_id not found")
|
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
if (
|
return (
|
||||||
connector.connector_type
|
connector.connector_type
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||||
):
|
)
|
||||||
|
|
||||||
|
def _get_composio_connected_account_id(
|
||||||
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
if cca_id:
|
if not cca_id:
|
||||||
return build_composio_credentials(cca_id)
|
|
||||||
raise ValueError("Composio connected_account_id not found")
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
async def _execute_composio_calendar_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict,
|
||||||
|
) -> tuple[dict | list | None, str | None]:
|
||||||
|
service = ComposioService()
|
||||||
|
result = await service.execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Calendar error")
|
||||||
|
|
||||||
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner), None
|
||||||
|
return inner, None
|
||||||
|
return data, None
|
||||||
|
|
||||||
|
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
raise ValueError(
|
||||||
|
"Composio Calendar connectors must use Composio tool execution"
|
||||||
|
)
|
||||||
|
|
||||||
config_data = dict(connector.config)
|
config_data = dict(connector.config)
|
||||||
|
|
||||||
|
|
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
_data, error = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_GET_CALENDAR",
|
||||||
|
{"calendar_id": "primary"},
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
|
|
@ -255,6 +297,23 @@ class GoogleCalendarToolMetadataService:
|
||||||
timezone_str = ""
|
timezone_str = ""
|
||||||
if connector:
|
if connector:
|
||||||
try:
|
try:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
cal_list, cal_error = await self._execute_composio_calendar_tool(
|
||||||
|
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
|
||||||
|
)
|
||||||
|
if cal_error:
|
||||||
|
raise RuntimeError(cal_error)
|
||||||
|
(
|
||||||
|
settings,
|
||||||
|
settings_error,
|
||||||
|
) = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_SETTINGS_GET",
|
||||||
|
{"setting": "timezone"},
|
||||||
|
)
|
||||||
|
if not settings_error and isinstance(settings, dict):
|
||||||
|
timezone_str = settings.get("value", "")
|
||||||
|
else:
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
service = await loop.run_in_executor(
|
service = await loop.run_in_executor(
|
||||||
|
|
@ -264,7 +323,22 @@ class GoogleCalendarToolMetadataService:
|
||||||
cal_list = await loop.run_in_executor(
|
cal_list = await loop.run_in_executor(
|
||||||
None, lambda: service.calendarList().list().execute()
|
None, lambda: service.calendarList().list().execute()
|
||||||
)
|
)
|
||||||
for cal in cal_list.get("items", []):
|
|
||||||
|
tz_setting = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: service.settings().get(setting="timezone").execute(),
|
||||||
|
)
|
||||||
|
timezone_str = tz_setting.get("value", "")
|
||||||
|
|
||||||
|
calendar_items = []
|
||||||
|
if isinstance(cal_list, dict):
|
||||||
|
calendar_items = (
|
||||||
|
cal_list.get("items") or cal_list.get("calendars") or []
|
||||||
|
)
|
||||||
|
elif isinstance(cal_list, list):
|
||||||
|
calendar_items = cal_list
|
||||||
|
|
||||||
|
for cal in calendar_items:
|
||||||
calendars.append(
|
calendars.append(
|
||||||
{
|
{
|
||||||
"id": cal.get("id", ""),
|
"id": cal.get("id", ""),
|
||||||
|
|
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
|
||||||
"primary": cal.get("primary", False),
|
"primary": cal.get("primary", False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
tz_setting = await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: service.settings().get(setting="timezone").execute(),
|
|
||||||
)
|
|
||||||
timezone_str = tz_setting.get("value", "")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to fetch calendars/timezone for connector %s",
|
"Failed to fetch calendars/timezone for connector %s",
|
||||||
|
|
@ -321,12 +389,21 @@ class GoogleCalendarToolMetadataService:
|
||||||
|
|
||||||
event_dict = event.to_dict()
|
event_dict = event.to_dict()
|
||||||
try:
|
try:
|
||||||
|
calendar_id = event.calendar_id or "primary"
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
live_event, error = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_EVENTS_GET",
|
||||||
|
{"calendar_id": calendar_id, "event_id": event.event_id},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
creds = await self._build_credentials(connector)
|
creds = await self._build_credentials(connector)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
service = await loop.run_in_executor(
|
service = await loop.run_in_executor(
|
||||||
None, lambda: build("calendar", "v3", credentials=creds)
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
)
|
)
|
||||||
calendar_id = event.calendar_id or "primary"
|
|
||||||
live_event = await loop.run_in_executor(
|
live_event = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: (
|
lambda: (
|
||||||
|
|
@ -376,14 +453,32 @@ class GoogleCalendarToolMetadataService:
|
||||||
) -> dict:
|
) -> dict:
|
||||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||||
if not resolved:
|
if not resolved:
|
||||||
|
live_resolved = await self._resolve_live_event(
|
||||||
|
search_space_id, user_id, event_ref
|
||||||
|
)
|
||||||
|
if not live_resolved:
|
||||||
return {
|
return {
|
||||||
"error": (
|
"error": (
|
||||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
|
||||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
"This could mean: (1) the event doesn't exist, "
|
||||||
"or (3) the event name is different."
|
"(2) the event name is different, or "
|
||||||
|
"(3) the connected calendar account cannot access it."
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connector, live_event = live_resolved
|
||||||
|
account = GoogleCalendarAccount.from_connector(connector)
|
||||||
|
acc_dict = account.to_dict()
|
||||||
|
auth_expired = await self._check_account_health(connector.id)
|
||||||
|
acc_dict["auth_expired"] = auth_expired
|
||||||
|
if auth_expired:
|
||||||
|
await self._persist_auth_expired(connector.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"account": acc_dict,
|
||||||
|
"event": self._event_dict_from_live_event(live_event),
|
||||||
|
}
|
||||||
|
|
||||||
document, connector = resolved
|
document, connector = resolved
|
||||||
account = GoogleCalendarAccount.from_connector(connector)
|
account = GoogleCalendarAccount.from_connector(connector)
|
||||||
event = GoogleCalendarEvent.from_document(document)
|
event = GoogleCalendarEvent.from_document(document)
|
||||||
|
|
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
|
||||||
if row:
|
if row:
|
||||||
return row[0], row[1]
|
return row[0], row[1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _resolve_live_event(
|
||||||
|
self, search_space_id: int, user_id: str, event_ref: str
|
||||||
|
) -> tuple[SearchSourceConnector, dict] | None:
|
||||||
|
result = await self._db_session.execute(
|
||||||
|
select(SearchSourceConnector)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||||
|
)
|
||||||
|
connectors = result.scalars().all()
|
||||||
|
|
||||||
|
for connector in connectors:
|
||||||
|
try:
|
||||||
|
events = await self._search_live_events(connector, event_ref)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to search live calendar events for connector %s",
|
||||||
|
connector.id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_ref = event_ref.strip().lower()
|
||||||
|
exact_match = next(
|
||||||
|
(
|
||||||
|
event
|
||||||
|
for event in events
|
||||||
|
if event.get("summary", "").strip().lower() == normalized_ref
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return connector, exact_match or events[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _search_live_events(
|
||||||
|
self, connector: SearchSourceConnector, event_ref: str
|
||||||
|
) -> list[dict]:
|
||||||
|
if self._is_composio_connector(connector):
|
||||||
|
data, error = await self._execute_composio_calendar_tool(
|
||||||
|
connector,
|
||||||
|
"GOOGLECALENDAR_EVENTS_LIST",
|
||||||
|
{
|
||||||
|
"calendar_id": "primary",
|
||||||
|
"q": event_ref,
|
||||||
|
"max_results": 10,
|
||||||
|
"single_events": True,
|
||||||
|
"order_by": "startTime",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data.get("items") or data.get("events") or []
|
||||||
|
return data if isinstance(data, list) else []
|
||||||
|
|
||||||
|
creds = await self._build_credentials(connector)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
service = await loop.run_in_executor(
|
||||||
|
None, lambda: build("calendar", "v3", credentials=creds)
|
||||||
|
)
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
service.events()
|
||||||
|
.list(
|
||||||
|
calendarId="primary",
|
||||||
|
q=event_ref,
|
||||||
|
maxResults=10,
|
||||||
|
singleEvents=True,
|
||||||
|
orderBy="startTime",
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response.get("items", [])
|
||||||
|
|
||||||
|
def _event_dict_from_live_event(self, event: dict) -> dict:
|
||||||
|
start_data = event.get("start", {})
|
||||||
|
end_data = event.get("end", {})
|
||||||
|
return {
|
||||||
|
"event_id": event.get("id", ""),
|
||||||
|
"summary": event.get("summary", "No Title"),
|
||||||
|
"start": start_data.get("dateTime", start_data.get("date", "")),
|
||||||
|
"end": end_data.get("dateTime", end_data.get("date", "")),
|
||||||
|
"description": event.get("description", ""),
|
||||||
|
"location": event.get("location", ""),
|
||||||
|
"attendees": [
|
||||||
|
{
|
||||||
|
"email": attendee.get("email", ""),
|
||||||
|
"responseStatus": attendee.get("responseStatus", ""),
|
||||||
|
}
|
||||||
|
for attendee in event.get("attendees", [])
|
||||||
|
],
|
||||||
|
"calendar_id": event.get("calendarId", "primary"),
|
||||||
|
"document_id": None,
|
||||||
|
"indexed_at": None,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import build_composio_credentials
|
from app.services.composio_service import ComposioService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
|
||||||
def __init__(self, db_session: AsyncSession):
|
def __init__(self, db_session: AsyncSession):
|
||||||
self._db_session = db_session
|
self._db_session = db_session
|
||||||
|
|
||||||
|
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||||
|
return (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_composio_connected_account_id(
|
||||||
|
self, connector: SearchSourceConnector
|
||||||
|
) -> str:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
raise ValueError("Composio connected_account_id not found")
|
||||||
|
return cca_id
|
||||||
|
|
||||||
|
async def _execute_composio_drive_tool(
|
||||||
|
self,
|
||||||
|
connector: SearchSourceConnector,
|
||||||
|
tool_name: str,
|
||||||
|
params: dict,
|
||||||
|
) -> tuple[dict | list | None, str | None]:
|
||||||
|
result = await ComposioService().execute_tool(
|
||||||
|
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||||
|
tool_name=tool_name,
|
||||||
|
params=params,
|
||||||
|
entity_id=f"surfsense_{connector.user_id}",
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown Composio Drive error")
|
||||||
|
data = result.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner = data.get("data", data)
|
||||||
|
if isinstance(inner, dict):
|
||||||
|
return inner.get("response_data", inner), None
|
||||||
|
return inner, None
|
||||||
|
return data, None
|
||||||
|
|
||||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||||
|
|
||||||
|
|
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
|
||||||
if not connector:
|
if not connector:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
pre_built_creds = None
|
if self._is_composio_connector(connector):
|
||||||
if (
|
_data, error = await self._execute_composio_drive_tool(
|
||||||
connector.connector_type
|
connector,
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
"GOOGLEDRIVE_LIST_FILES",
|
||||||
):
|
{
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
"q": "trashed = false",
|
||||||
if cca_id:
|
"page_size": 1,
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
"fields": "files(id)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return bool(error)
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=self._db_session,
|
session=self._db_session,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
await client.list_files(
|
await client.list_files(
|
||||||
query="trashed = false", page_size=1, fields="files(id)"
|
query="trashed = false", page_size=1, fields="files(id)"
|
||||||
|
|
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
|
||||||
parent_folders[connector_id] = []
|
parent_folders[connector_id] = []
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pre_built_creds = None
|
if self._is_composio_connector(connector):
|
||||||
if (
|
data, error = await self._execute_composio_drive_tool(
|
||||||
connector.connector_type
|
connector,
|
||||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
"GOOGLEDRIVE_LIST_FILES",
|
||||||
):
|
{
|
||||||
cca_id = connector.config.get("composio_connected_account_id")
|
"q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
|
||||||
if cca_id:
|
"fields": "files(id,name)",
|
||||||
pre_built_creds = build_composio_credentials(cca_id)
|
"page_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to list folders for connector %s: %s",
|
||||||
|
connector_id,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
parent_folders[connector_id] = []
|
||||||
|
continue
|
||||||
|
folders = []
|
||||||
|
if isinstance(data, dict):
|
||||||
|
folders = data.get("files", [])
|
||||||
|
elif isinstance(data, list):
|
||||||
|
folders = data
|
||||||
|
parent_folders[connector_id] = [
|
||||||
|
{"folder_id": f["id"], "name": f["name"]}
|
||||||
|
for f in folders
|
||||||
|
if f.get("id") and f.get("name")
|
||||||
|
]
|
||||||
|
continue
|
||||||
|
|
||||||
client = GoogleDriveClient(
|
client = GoogleDriveClient(
|
||||||
session=self._db_session,
|
session=self._db_session,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
credentials=pre_built_creds,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
folders, _, error = await client.list_files(
|
folders, _, error = await client.list_files(
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@ from typing import Any
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
from litellm.utils import ImageResponse
|
from litellm.utils import ImageResponse
|
||||||
|
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Special ID for Auto mode - uses router for load balancing
|
# Special ID for Auto mode - uses router for load balancing
|
||||||
|
|
@ -152,10 +154,10 @@ class ImageGenRouterService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build model string
|
# Build model string
|
||||||
if config.get("custom_provider"):
|
|
||||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
|
||||||
else:
|
|
||||||
provider = config.get("provider", "").upper()
|
provider = config.get("provider", "").upper()
|
||||||
|
if config.get("custom_provider"):
|
||||||
|
provider_prefix = config["custom_provider"]
|
||||||
|
else:
|
||||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||||
|
|
||||||
|
|
@ -165,9 +167,16 @@ class ImageGenRouterService:
|
||||||
"api_key": config.get("api_key"),
|
"api_key": config.get("api_key"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional api_base
|
# Resolve ``api_base`` so deployments don't silently inherit
|
||||||
if config.get("api_base"):
|
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||||
litellm_params["api_base"] = config["api_base"]
|
# the wrong provider (see ``provider_api_base`` docstring).
|
||||||
|
api_base = resolve_api_base(
|
||||||
|
provider=provider,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=config.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
litellm_params["api_base"] = api_base
|
||||||
|
|
||||||
# Add api_version (required for Azure)
|
# Add api_version (required for Azure)
|
||||||
if config.get("api_version"):
|
if config.get("api_version"):
|
||||||
|
|
|
||||||
|
|
@ -140,8 +140,6 @@ PROVIDER_MAP = {
|
||||||
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||||
# backward compatibility with any external import.
|
# backward compatibility with any external import.
|
||||||
from app.services.provider_api_base import ( # noqa: E402
|
from app.services.provider_api_base import ( # noqa: E402
|
||||||
PROVIDER_DEFAULT_API_BASE,
|
|
||||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
|
||||||
resolve_api_base,
|
resolve_api_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from app.services.llm_router_service import (
|
||||||
get_auto_mode_llm,
|
get_auto_mode_llm,
|
||||||
is_auto_mode,
|
is_auto_mode,
|
||||||
)
|
)
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
from app.services.token_tracking_service import token_tracker
|
from app.services.token_tracking_service import token_tracker
|
||||||
|
|
||||||
# Configure litellm to automatically drop unsupported parameters
|
# Configure litellm to automatically drop unsupported parameters
|
||||||
|
|
@ -556,22 +557,26 @@ async def get_vision_llm(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if global_cfg.get("custom_provider"):
|
if global_cfg.get("custom_provider"):
|
||||||
model_string = (
|
provider_prefix = global_cfg["custom_provider"]
|
||||||
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
|
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
prefix = VISION_PROVIDER_MAP.get(
|
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||||
global_cfg["provider"].upper(),
|
global_cfg["provider"].upper(),
|
||||||
global_cfg["provider"].lower(),
|
global_cfg["provider"].lower(),
|
||||||
)
|
)
|
||||||
model_string = f"{prefix}/{global_cfg['model_name']}"
|
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||||
|
|
||||||
litellm_kwargs = {
|
litellm_kwargs = {
|
||||||
"model": model_string,
|
"model": model_string,
|
||||||
"api_key": global_cfg["api_key"],
|
"api_key": global_cfg["api_key"],
|
||||||
}
|
}
|
||||||
if global_cfg.get("api_base"):
|
api_base = resolve_api_base(
|
||||||
litellm_kwargs["api_base"] = global_cfg["api_base"]
|
provider=global_cfg.get("provider"),
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=global_cfg.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
litellm_kwargs["api_base"] = api_base
|
||||||
if global_cfg.get("litellm_params"):
|
if global_cfg.get("litellm_params"):
|
||||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||||
|
|
||||||
|
|
@ -606,20 +611,26 @@ async def get_vision_llm(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if vision_cfg.custom_provider:
|
if vision_cfg.custom_provider:
|
||||||
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
|
provider_prefix = vision_cfg.custom_provider
|
||||||
|
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||||
else:
|
else:
|
||||||
prefix = VISION_PROVIDER_MAP.get(
|
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||||
vision_cfg.provider.value.upper(),
|
vision_cfg.provider.value.upper(),
|
||||||
vision_cfg.provider.value.lower(),
|
vision_cfg.provider.value.lower(),
|
||||||
)
|
)
|
||||||
model_string = f"{prefix}/{vision_cfg.model_name}"
|
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||||
|
|
||||||
litellm_kwargs = {
|
litellm_kwargs = {
|
||||||
"model": model_string,
|
"model": model_string,
|
||||||
"api_key": vision_cfg.api_key,
|
"api_key": vision_cfg.api_key,
|
||||||
}
|
}
|
||||||
if vision_cfg.api_base:
|
api_base = resolve_api_base(
|
||||||
litellm_kwargs["api_base"] = vision_cfg.api_base
|
provider=vision_cfg.provider.value,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=vision_cfg.api_base,
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
litellm_kwargs["api_base"] = api_base
|
||||||
if vision_cfg.litellm_params:
|
if vision_cfg.litellm_params:
|
||||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool:
|
||||||
return "image" in input_mods and "text" in output_mods
|
return "image" in input_mods and "text" in output_mods
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_image_input(model: dict) -> bool:
|
||||||
|
"""Return True if the model accepts ``image`` in its input modalities.
|
||||||
|
|
||||||
|
Differs from :func:`_is_vision_input_model` in that it does NOT
|
||||||
|
require text output — chat-tab models always emit text already (the
|
||||||
|
chat catalog filters by ``_is_text_output_model``), so the only
|
||||||
|
extra capability we need to track per chat config is whether the
|
||||||
|
model can ingest user-attached images. The chat selector and the
|
||||||
|
streaming task both key off this flag to prevent hitting an
|
||||||
|
OpenRouter 404 ``"No endpoints found that support image input"``
|
||||||
|
when the user uploads an image and selects a text-only model
|
||||||
|
(DeepSeek V3, Llama 3.x base, etc.).
|
||||||
|
"""
|
||||||
|
arch = model.get("architecture", {}) or {}
|
||||||
|
input_mods = arch.get("input_modalities", []) or []
|
||||||
|
return "image" in input_mods
|
||||||
|
|
||||||
|
|
||||||
def _supports_tool_calling(model: dict) -> bool:
|
def _supports_tool_calling(model: dict) -> bool:
|
||||||
"""Return True if the model supports function/tool calling."""
|
"""Return True if the model supports function/tool calling."""
|
||||||
supported = model.get("supported_parameters") or []
|
supported = model.get("supported_parameters") or []
|
||||||
|
|
@ -321,6 +339,13 @@ def _generate_configs(
|
||||||
# account-wide quota, so per-deployment routing can't spread load
|
# account-wide quota, so per-deployment routing can't spread load
|
||||||
# there — it just drains the shared bucket faster.
|
# there — it just drains the shared bucket faster.
|
||||||
"router_pool_eligible": tier == "premium",
|
"router_pool_eligible": tier == "premium",
|
||||||
|
# Capability flag derived from ``architecture.input_modalities``.
|
||||||
|
# Read by the new-chat selector to dim image-incompatible models
|
||||||
|
# when the user has pending image attachments, and by
|
||||||
|
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||||
|
# OpenRouter request would otherwise 404 with
|
||||||
|
# ``"No endpoints found that support image input"``.
|
||||||
|
"supports_image_input": _supports_image_input(model),
|
||||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||||
# to the static score and gets re-blended with health on the next
|
# to the static score and gets re-blended with health on the next
|
||||||
|
|
@ -398,7 +423,12 @@ def _generate_image_gen_configs(
|
||||||
"provider": "OPENROUTER",
|
"provider": "OPENROUTER",
|
||||||
"model_name": model_id,
|
"model_name": model_id,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"api_base": "",
|
# Pin to OpenRouter's public base URL so a downstream call site
|
||||||
|
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||||
|
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
|
||||||
|
# ``image_generation/transformation`` (defense-in-depth, see
|
||||||
|
# ``provider_api_base`` docstring).
|
||||||
|
"api_base": "https://openrouter.ai/api/v1",
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"rpm": free_rpm if tier == "free" else rpm,
|
"rpm": free_rpm if tier == "free" else rpm,
|
||||||
"litellm_params": dict(litellm_params),
|
"litellm_params": dict(litellm_params),
|
||||||
|
|
@ -477,7 +507,11 @@ def _generate_vision_llm_configs(
|
||||||
"provider": "OPENROUTER",
|
"provider": "OPENROUTER",
|
||||||
"model_name": model_id,
|
"model_name": model_id,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"api_base": "",
|
# Pin to OpenRouter's public base URL so a downstream call site
|
||||||
|
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||||
|
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
|
||||||
|
# ``provider_api_base`` docstring).
|
||||||
|
"api_base": "https://openrouter.ai/api/v1",
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"rpm": free_rpm if tier == "free" else rpm,
|
"rpm": free_rpm if tier == "free" else rpm,
|
||||||
"tpm": free_tpm if tier == "free" else tpm,
|
"tpm": free_tpm if tier == "free" else tpm,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ source of truth without an inter-service circular import.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||||
"openrouter": "https://openrouter.ai/api/v1",
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
"groq": "https://api.groq.com/openai/v1",
|
"groq": "https://api.groq.com/openai/v1",
|
||||||
|
|
|
||||||
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
"""Capability resolution shared by chat / image / vision call sites.
|
||||||
|
|
||||||
|
Why this exists
|
||||||
|
---------------
|
||||||
|
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
|
||||||
|
single, authoritative answer to one question: *can this chat config accept
|
||||||
|
``image_url`` content blocks?* Without it, the new-chat selector can't badge
|
||||||
|
incompatible models and the streaming task can't fail fast with a friendly
|
||||||
|
error before sending an image to a text-only provider.
|
||||||
|
|
||||||
|
Two functions, two intents:
|
||||||
|
|
||||||
|
- :func:`derive_supports_image_input` — best-effort *True* for catalog and
|
||||||
|
UI surfacing. Default-allow: an unknown / unmapped model is treated as
|
||||||
|
capable so we never lock the user out of a freshly added or
|
||||||
|
third-party-hosted vision model.
|
||||||
|
|
||||||
|
- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
|
||||||
|
task's safety net. Returns True only when LiteLLM's model map *explicitly*
|
||||||
|
sets ``supports_vision=False`` (or its bare-name variant does). Anything
|
||||||
|
else — missing key, lookup exception, ``supports_vision=True`` — returns
|
||||||
|
False so the request flows through to the provider.
|
||||||
|
|
||||||
|
Implementation rule: only public LiteLLM symbols
|
||||||
|
------------------------------------------------
|
||||||
|
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
|
||||||
|
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
|
||||||
|
across releases. The private ``_is_explicitly_disabled_factory`` and
|
||||||
|
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
|
||||||
|
can't silently break us.
|
||||||
|
|
||||||
|
Why the previous round's strict YAML opt-in flag failed
|
||||||
|
-------------------------------------------------------
|
||||||
|
``supports_image_input: false`` was the YAML loader's setdefault. Operators
|
||||||
|
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
|
||||||
|
YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
|
||||||
|
False and the streaming gate rejected every image turn. Sourcing capability
|
||||||
|
from LiteLLM's authoritative model map (which already says
|
||||||
|
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Provider-name → LiteLLM model-prefix map.
|
||||||
|
#
|
||||||
|
# Owned here because ``app.services.provider_capabilities`` is the
|
||||||
|
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||||
|
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
|
||||||
|
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||||
|
# map there directly would re-introduce the
|
||||||
|
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
||||||
|
# app.config`` cycle that prompted the move.
|
||||||
|
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||||
|
"OPENAI": "openai",
|
||||||
|
"ANTHROPIC": "anthropic",
|
||||||
|
"GROQ": "groq",
|
||||||
|
"COHERE": "cohere",
|
||||||
|
"GOOGLE": "gemini",
|
||||||
|
"OLLAMA": "ollama_chat",
|
||||||
|
"MISTRAL": "mistral",
|
||||||
|
"AZURE_OPENAI": "azure",
|
||||||
|
"OPENROUTER": "openrouter",
|
||||||
|
"XAI": "xai",
|
||||||
|
"BEDROCK": "bedrock",
|
||||||
|
"VERTEX_AI": "vertex_ai",
|
||||||
|
"TOGETHER_AI": "together_ai",
|
||||||
|
"FIREWORKS_AI": "fireworks_ai",
|
||||||
|
"DEEPSEEK": "openai",
|
||||||
|
"ALIBABA_QWEN": "openai",
|
||||||
|
"MOONSHOT": "openai",
|
||||||
|
"ZHIPU": "openai",
|
||||||
|
"GITHUB_MODELS": "github",
|
||||||
|
"REPLICATE": "replicate",
|
||||||
|
"PERPLEXITY": "perplexity",
|
||||||
|
"ANYSCALE": "anyscale",
|
||||||
|
"DEEPINFRA": "deepinfra",
|
||||||
|
"CEREBRAS": "cerebras",
|
||||||
|
"SAMBANOVA": "sambanova",
|
||||||
|
"AI21": "ai21",
|
||||||
|
"CLOUDFLARE": "cloudflare",
|
||||||
|
"DATABRICKS": "databricks",
|
||||||
|
"COMETAPI": "cometapi",
|
||||||
|
"HUGGINGFACE": "huggingface",
|
||||||
|
"MINIMAX": "openai",
|
||||||
|
"CUSTOM": "custom",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_model_strings(
|
||||||
|
*,
|
||||||
|
provider: str | None,
|
||||||
|
model_name: str | None,
|
||||||
|
base_model: str | None,
|
||||||
|
custom_provider: str | None,
|
||||||
|
) -> list[tuple[str, str | None]]:
|
||||||
|
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
|
||||||
|
|
||||||
|
LiteLLM's capability lookup is keyed by ``model`` + (optional)
|
||||||
|
``custom_llm_provider``. Different config sources give us different
|
||||||
|
levels of detail, so we try the most-specific keys first and fall back
|
||||||
|
to bare model names so unannotated entries (e.g. an Azure deployment
|
||||||
|
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
|
||||||
|
map. Order matters — the first lookup that returns a definitive answer
|
||||||
|
wins for both helpers.
|
||||||
|
"""
|
||||||
|
candidates: list[tuple[str, str | None]] = []
|
||||||
|
seen: set[tuple[str, str | None]] = set()
|
||||||
|
|
||||||
|
def _add(model: str | None, llm_provider: str | None) -> None:
|
||||||
|
if not model:
|
||||||
|
return
|
||||||
|
key = (model, llm_provider)
|
||||||
|
if key in seen:
|
||||||
|
return
|
||||||
|
seen.add(key)
|
||||||
|
candidates.append(key)
|
||||||
|
|
||||||
|
provider_prefix: str | None = None
|
||||||
|
if provider:
|
||||||
|
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
|
||||||
|
if custom_provider:
|
||||||
|
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
|
||||||
|
provider_prefix = custom_provider
|
||||||
|
|
||||||
|
primary_model = base_model or model_name
|
||||||
|
bare_model = model_name
|
||||||
|
|
||||||
|
# Most-specific first: provider-prefixed identifier with explicit
|
||||||
|
# custom_llm_provider so LiteLLM won't have to guess the provider via
|
||||||
|
# ``get_llm_provider``.
|
||||||
|
if primary_model and provider_prefix:
|
||||||
|
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
|
||||||
|
if "/" in primary_model:
|
||||||
|
_add(primary_model, provider_prefix)
|
||||||
|
else:
|
||||||
|
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
|
||||||
|
|
||||||
|
# Bare base_model (or model_name) with provider hint — handles entries
|
||||||
|
# the upstream map keys without a provider prefix (most ``gpt-*`` and
|
||||||
|
# ``claude-*`` entries do this).
|
||||||
|
if primary_model:
|
||||||
|
_add(primary_model, provider_prefix)
|
||||||
|
|
||||||
|
# Fallback to model_name when base_model differs (e.g. an Azure
|
||||||
|
# deployment whose model_name is the deployment id but base_model is the
|
||||||
|
# canonical OpenAI sku).
|
||||||
|
if bare_model and bare_model != primary_model:
|
||||||
|
if provider_prefix and "/" not in bare_model:
|
||||||
|
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
|
||||||
|
_add(bare_model, provider_prefix)
|
||||||
|
_add(bare_model, None)
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
|
||||||
|
def derive_supports_image_input(
|
||||||
|
*,
|
||||||
|
provider: str | None = None,
|
||||||
|
model_name: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
custom_provider: str | None = None,
|
||||||
|
openrouter_input_modalities: Iterable[str] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Best-effort capability flag for the new-chat selector and catalog.
|
||||||
|
|
||||||
|
Resolution order (first definitive answer wins):
|
||||||
|
|
||||||
|
1. ``openrouter_input_modalities`` (when provided as a non-empty
|
||||||
|
iterable). OpenRouter exposes ``architecture.input_modalities`` per
|
||||||
|
model and that's the authoritative source for OR dynamic configs.
|
||||||
|
2. ``litellm.supports_vision`` against each candidate identifier from
|
||||||
|
:func:`_candidate_model_strings`. Returns True as soon as any
|
||||||
|
candidate confirms vision support.
|
||||||
|
3. Default ``True`` — the conservative-allow stance. An unknown /
|
||||||
|
newly-added / third-party-hosted model is *not* pre-judged. The
|
||||||
|
streaming safety net (:func:`is_known_text_only_chat_model`) is the
|
||||||
|
only place a False ever blocks; everywhere else, a False here would
|
||||||
|
just hide a usable model from the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model can plausibly accept image input, False only when
|
||||||
|
OpenRouter explicitly says it can't.
|
||||||
|
"""
|
||||||
|
if openrouter_input_modalities is not None:
|
||||||
|
modalities = list(openrouter_input_modalities)
|
||||||
|
if modalities:
|
||||||
|
return "image" in modalities
|
||||||
|
# Empty list explicitly published by OR — treat as "no image".
|
||||||
|
return False
|
||||||
|
|
||||||
|
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||||
|
provider=provider,
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=custom_provider,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if litellm.supports_vision(
|
||||||
|
model=model_string, custom_llm_provider=custom_llm_provider
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"litellm.supports_vision raised for model=%s provider=%s: %s",
|
||||||
|
model_string,
|
||||||
|
custom_llm_provider,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_known_text_only_chat_model(
|
||||||
|
*,
|
||||||
|
provider: str | None = None,
|
||||||
|
model_name: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
custom_provider: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Strict opt-out probe for the streaming-task safety net.
|
||||||
|
|
||||||
|
Returns True only when LiteLLM's model map *explicitly* sets
|
||||||
|
``supports_vision=False`` for at least one candidate identifier. Missing
|
||||||
|
key, lookup exception, or ``supports_vision=True`` all return False so
|
||||||
|
the streaming task lets the request through. This is the inverse-default
|
||||||
|
of :func:`derive_supports_image_input`.
|
||||||
|
|
||||||
|
Why two functions
|
||||||
|
-----------------
|
||||||
|
The selector wants "show me everything that's plausibly capable" —
|
||||||
|
default-allow. The safety net wants "block only when I'm certain it
|
||||||
|
can't" — default-pass. Mixing the two intents in a single function
|
||||||
|
leads to the regression we're fixing here.
|
||||||
|
"""
|
||||||
|
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||||
|
provider=provider,
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=custom_provider,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
info = litellm.get_model_info(
|
||||||
|
model=model_string, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"litellm.get_model_info raised for model=%s provider=%s: %s",
|
||||||
|
model_string,
|
||||||
|
custom_llm_provider,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
|
||||||
|
# may be missing, None, True, or False. We only fire on explicit
|
||||||
|
# False — None / missing / True all mean "don't block".
|
||||||
|
try:
|
||||||
|
value = info.get("supports_vision") # type: ignore[union-attr]
|
||||||
|
except AttributeError:
|
||||||
|
value = None
|
||||||
|
if value is False:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"derive_supports_image_input",
|
||||||
|
"is_known_text_only_chat_model",
|
||||||
|
]
|
||||||
|
|
@ -1,10 +1,25 @@
|
||||||
"""Celery tasks package."""
|
"""Celery tasks package.
|
||||||
|
|
||||||
|
Also hosts the small helpers every async celery task should use to
|
||||||
|
spin up its event loop. See :func:`run_async_celery_task` for the
|
||||||
|
canonical pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.pool import NullPool
|
from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_celery_engine = None
|
_celery_engine = None
|
||||||
_celery_session_maker = None
|
_celery_session_maker = None
|
||||||
|
|
||||||
|
|
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
|
||||||
_celery_engine, expire_on_commit=False
|
_celery_engine, expire_on_commit=False
|
||||||
)
|
)
|
||||||
return _celery_session_maker
|
return _celery_session_maker
|
||||||
|
|
||||||
|
|
||||||
|
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
"""Drop the shared ``app.db.engine`` connection pool synchronously.
|
||||||
|
|
||||||
|
The shared engine (used by ``shielded_async_session`` and most
|
||||||
|
routes / services) is a module-level singleton with a real pool.
|
||||||
|
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
|
||||||
|
connections cache a reference to whichever loop opened them. When
|
||||||
|
a subsequent task's loop pulls a stale connection from the pool,
|
||||||
|
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
|
||||||
|
|
||||||
|
AttributeError: 'NoneType' object has no attribute 'send'
|
||||||
|
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
|
||||||
|
self._write_fut = self._loop._proactor.send(self._sock, data)
|
||||||
|
|
||||||
|
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
|
||||||
|
coroutine that can never run because its loop is gone.
|
||||||
|
|
||||||
|
Disposing the engine forces the pool to drop every cached
|
||||||
|
connection so the next checkout opens a fresh one on the current
|
||||||
|
loop. Safe to call from a task's finally block; failure is logged
|
||||||
|
but never propagated.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.db import engine as shared_engine
|
||||||
|
|
||||||
|
loop.run_until_complete(shared_engine.dispose())
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Shared DB engine dispose() failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
|
||||||
|
"""Run an async coroutine inside a fresh event loop with proper
|
||||||
|
DB-engine cleanup.
|
||||||
|
|
||||||
|
This is the canonical entry point for every async celery task.
|
||||||
|
It performs three responsibilities that were previously copy-pasted
|
||||||
|
(incorrectly) across each task module:
|
||||||
|
|
||||||
|
1. Create a fresh ``asyncio`` loop and install it on the current
|
||||||
|
thread (celery's ``--pool=solo`` runs every task on the main
|
||||||
|
thread, but other pool types don't).
|
||||||
|
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
|
||||||
|
any stale connections left over from a previous task's loop
|
||||||
|
are dropped — defends against tasks that crashed without
|
||||||
|
cleaning up.
|
||||||
|
3. Dispose the shared engine AFTER the task runs so the
|
||||||
|
connections we opened on this loop are released before the
|
||||||
|
loop closes (avoids ``coroutine 'Connection._cancel' was
|
||||||
|
never awaited`` warnings and the next-task hang).
|
||||||
|
|
||||||
|
Use as::
|
||||||
|
|
||||||
|
@celery_app.task(name="my_task", bind=True)
|
||||||
|
def my_task(self, *args):
|
||||||
|
return run_async_celery_task(lambda: _my_task_impl(*args))
|
||||||
|
"""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
# Defense-in-depth: prior task may have crashed before
|
||||||
|
# disposing. Idempotent — no-op if pool is already empty.
|
||||||
|
_dispose_shared_db_engine(loop)
|
||||||
|
return loop.run_until_complete(coro_factory())
|
||||||
|
finally:
|
||||||
|
# Drop any connections this task opened so they don't leak
|
||||||
|
# into the next task's loop.
|
||||||
|
_dispose_shared_db_engine(loop)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_celery_session_maker",
|
||||||
|
"run_async_celery_task",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -49,22 +49,15 @@ def index_notion_pages_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Notion pages."""
|
"""Celery task to index Notion pages."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
return run_async_celery_task(
|
||||||
_index_notion_pages(
|
lambda: _index_notion_pages(
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_notion_pages(
|
async def _index_notion_pages(
|
||||||
|
|
@ -95,19 +88,11 @@ def index_github_repos_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index GitHub repositories."""
|
"""Celery task to index GitHub repositories."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_github_repos(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_github_repos(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_github_repos(
|
async def _index_github_repos(
|
||||||
|
|
@ -138,19 +123,11 @@ def index_confluence_pages_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Confluence pages."""
|
"""Celery task to index Confluence pages."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_confluence_pages(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_confluence_pages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_confluence_pages(
|
async def _index_confluence_pages(
|
||||||
|
|
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Google Calendar events."""
|
"""Celery task to index Google Calendar events."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
return run_async_celery_task(
|
||||||
_index_google_calendar_events(
|
lambda: _index_google_calendar_events(
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_google_calendar_events(
|
async def _index_google_calendar_events(
|
||||||
|
|
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Google Gmail messages."""
|
"""Celery task to index Google Gmail messages."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_google_gmail_messages(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_google_gmail_messages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_google_gmail_messages(
|
async def _index_google_gmail_messages(
|
||||||
|
|
@ -269,22 +231,14 @@ def index_google_drive_files_task(
|
||||||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||||
):
|
):
|
||||||
"""Celery task to index Google Drive folders and files."""
|
"""Celery task to index Google Drive folders and files."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_google_drive_files(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_google_drive_files(
|
|
||||||
connector_id,
|
connector_id,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
user_id,
|
user_id,
|
||||||
items_dict,
|
items_dict,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_google_drive_files(
|
async def _index_google_drive_files(
|
||||||
|
|
@ -317,22 +271,14 @@ def index_onedrive_files_task(
|
||||||
items_dict: dict,
|
items_dict: dict,
|
||||||
):
|
):
|
||||||
"""Celery task to index OneDrive folders and files."""
|
"""Celery task to index OneDrive folders and files."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_onedrive_files(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_onedrive_files(
|
|
||||||
connector_id,
|
connector_id,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
user_id,
|
user_id,
|
||||||
items_dict,
|
items_dict,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_onedrive_files(
|
async def _index_onedrive_files(
|
||||||
|
|
@ -365,22 +311,14 @@ def index_dropbox_files_task(
|
||||||
items_dict: dict,
|
items_dict: dict,
|
||||||
):
|
):
|
||||||
"""Celery task to index Dropbox folders and files."""
|
"""Celery task to index Dropbox folders and files."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_dropbox_files(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_dropbox_files(
|
|
||||||
connector_id,
|
connector_id,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
user_id,
|
user_id,
|
||||||
items_dict,
|
items_dict,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_dropbox_files(
|
async def _index_dropbox_files(
|
||||||
|
|
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Elasticsearch documents."""
|
"""Celery task to index Elasticsearch documents."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_elasticsearch_documents(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_elasticsearch_documents(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_elasticsearch_documents(
|
async def _index_elasticsearch_documents(
|
||||||
|
|
@ -457,22 +387,15 @@ def index_crawled_urls_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index Web page Urls."""
|
"""Celery task to index Web page Urls."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
return run_async_celery_task(
|
||||||
_index_crawled_urls(
|
lambda: _index_crawled_urls(
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_crawled_urls(
|
async def _index_crawled_urls(
|
||||||
|
|
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
):
|
):
|
||||||
"""Celery task to index BookStack pages."""
|
"""Celery task to index BookStack pages."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_bookstack_pages(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_bookstack_pages(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_bookstack_pages(
|
async def _index_bookstack_pages(
|
||||||
|
|
@ -546,19 +461,11 @@ def index_composio_connector_task(
|
||||||
end_date: str | None,
|
end_date: str | None,
|
||||||
):
|
):
|
||||||
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
|
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
|
||||||
import asyncio
|
return run_async_celery_task(
|
||||||
|
lambda: _index_composio_connector(
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_composio_connector(
|
|
||||||
connector_id, search_space_id, user_id, start_date, end_date
|
connector_id, search_space_id, user_id, start_date, end_date
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_composio_connector(
|
async def _index_composio_connector(
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from app.db import Document
|
||||||
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
|
||||||
document_id: ID of document to reindex
|
document_id: ID of document to reindex
|
||||||
user_id: ID of user who edited the document
|
user_id: ID of user who edited the document
|
||||||
"""
|
"""
|
||||||
import asyncio
|
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_reindex_document(document_id, user_id))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _reindex_document(document_id: int, user_id: str):
|
async def _reindex_document(document_id: int, user_id: str):
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from app.celery_app import celery_app
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.services.notification_service import NotificationService
|
from app.services.notification_service import NotificationService
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
from app.tasks.connector_indexers.local_folder_indexer import (
|
from app.tasks.connector_indexers.local_folder_indexer import (
|
||||||
index_local_folder,
|
index_local_folder,
|
||||||
index_uploaded_files,
|
index_uploaded_files,
|
||||||
|
|
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
|
||||||
)
|
)
|
||||||
def delete_document_task(self, document_id: int):
|
def delete_document_task(self, document_id: int):
|
||||||
"""Celery task to delete a document and its chunks in batches."""
|
"""Celery task to delete a document and its chunks in batches."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(lambda: _delete_document_background(document_id))
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_delete_document_background(document_id))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _delete_document_background(document_id: int) -> None:
|
async def _delete_document_background(document_id: int) -> None:
|
||||||
|
|
@ -153,14 +148,9 @@ def delete_folder_documents_task(
|
||||||
folder_subtree_ids: list[int] | None = None,
|
folder_subtree_ids: list[int] | None = None,
|
||||||
):
|
):
|
||||||
"""Celery task to delete documents first, then the folder rows."""
|
"""Celery task to delete documents first, then the folder rows."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _delete_folder_documents(
|
async def _delete_folder_documents(
|
||||||
|
|
@ -209,12 +199,9 @@ async def _delete_folder_documents(
|
||||||
)
|
)
|
||||||
def delete_search_space_task(self, search_space_id: int):
|
def delete_search_space_task(self, search_space_id: int):
|
||||||
"""Celery task to delete a search space and heavy child rows in batches."""
|
"""Celery task to delete a search space and heavy child rows in batches."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _delete_search_space_background(search_space_id)
|
||||||
try:
|
)
|
||||||
loop.run_until_complete(_delete_search_space_background(search_space_id))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _delete_search_space_background(search_space_id: int) -> None:
|
async def _delete_search_space_background(search_space_id: int) -> None:
|
||||||
|
|
@ -269,18 +256,11 @@ def process_extension_document_task(
|
||||||
search_space_id: ID of the search space
|
search_space_id: ID of the search space
|
||||||
user_id: ID of the user
|
user_id: ID of the user
|
||||||
"""
|
"""
|
||||||
# Create a new event loop for this task
|
return run_async_celery_task(
|
||||||
loop = asyncio.new_event_loop()
|
lambda: _process_extension_document(
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_process_extension_document(
|
|
||||||
individual_document_dict, search_space_id, user_id
|
individual_document_dict, search_space_id, user_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_extension_document(
|
async def _process_extension_document(
|
||||||
|
|
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
|
||||||
search_space_id: ID of the search space
|
search_space_id: ID of the search space
|
||||||
user_id: ID of the user
|
user_id: ID of the user
|
||||||
"""
|
"""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _process_youtube_video(url, search_space_id, user_id)
|
||||||
|
)
|
||||||
try:
|
|
||||||
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
||||||
|
|
@ -573,12 +549,9 @@ def process_file_upload_task(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[process_file_upload] Could not get file size: {e}")
|
logger.warning(f"[process_file_upload] Could not get file size: {e}")
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
run_async_celery_task(
|
||||||
_process_file_upload(file_path, filename, search_space_id, user_id)
|
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[process_file_upload] Task completed successfully for: {filename}"
|
f"[process_file_upload] Task completed successfully for: {filename}"
|
||||||
|
|
@ -589,8 +562,6 @@ def process_file_upload_task(
|
||||||
f"Traceback:\n{traceback.format_exc()}"
|
f"Traceback:\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_file_upload(
|
async def _process_file_upload(
|
||||||
|
|
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
|
||||||
"File may have been removed before syncing could start."
|
"File may have been removed before syncing could start."
|
||||||
)
|
)
|
||||||
# Mark document as failed since file is missing
|
# Mark document as failed since file is missing
|
||||||
loop = asyncio.new_event_loop()
|
run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _mark_document_failed(
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_mark_document_failed(
|
|
||||||
document_id,
|
document_id,
|
||||||
"File not found. Please re-upload the file.",
|
"File not found. Please re-upload the file.",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
run_async_celery_task(
|
||||||
_process_file_with_document(
|
lambda: _process_file_with_document(
|
||||||
document_id,
|
document_id,
|
||||||
temp_path,
|
temp_path,
|
||||||
filename,
|
filename,
|
||||||
|
|
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
|
||||||
f"Traceback:\n{traceback.format_exc()}"
|
f"Traceback:\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _mark_document_failed(document_id: int, reason: str):
|
async def _mark_document_failed(document_id: int, reason: str):
|
||||||
|
|
@ -1119,12 +1080,8 @@ def process_circleback_meeting_task(
|
||||||
search_space_id: ID of the search space
|
search_space_id: ID of the search space
|
||||||
connector_id: ID of the Circleback connector (for deletion support)
|
connector_id: ID of the Circleback connector (for deletion support)
|
||||||
"""
|
"""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _process_circleback_meeting(
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_process_circleback_meeting(
|
|
||||||
meeting_id,
|
meeting_id,
|
||||||
meeting_name,
|
meeting_name,
|
||||||
markdown_content,
|
markdown_content,
|
||||||
|
|
@ -1133,8 +1090,6 @@ def process_circleback_meeting_task(
|
||||||
connector_id,
|
connector_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_circleback_meeting(
|
async def _process_circleback_meeting(
|
||||||
|
|
@ -1291,12 +1246,8 @@ def index_local_folder_task(
|
||||||
target_file_paths: list[str] | None = None,
|
target_file_paths: list[str] | None = None,
|
||||||
):
|
):
|
||||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _index_local_folder_async(
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_local_folder_async(
|
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
folder_path=folder_path,
|
folder_path=folder_path,
|
||||||
|
|
@ -1308,8 +1259,6 @@ def index_local_folder_task(
|
||||||
target_file_paths=target_file_paths,
|
target_file_paths=target_file_paths,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_local_folder_async(
|
async def _index_local_folder_async(
|
||||||
|
|
@ -1441,11 +1390,8 @@ def index_uploaded_folder_files_task(
|
||||||
processing_mode: str = "basic",
|
processing_mode: str = "basic",
|
||||||
):
|
):
|
||||||
"""Celery task to index files uploaded from the desktop app."""
|
"""Celery task to index files uploaded from the desktop app."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _index_uploaded_folder_files_async(
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_uploaded_folder_files_async(
|
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
folder_name=folder_name,
|
folder_name=folder_name,
|
||||||
|
|
@ -1456,8 +1402,6 @@ def index_uploaded_folder_files_task(
|
||||||
processing_mode=processing_mode,
|
processing_mode=processing_mode,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_uploaded_folder_files_async(
|
async def _index_uploaded_folder_files_async(
|
||||||
|
|
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
|
||||||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
|
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
|
||||||
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
|
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
|
||||||
"""Full AI sort for all documents in a search space."""
|
"""Full AI sort for all documents in a search space."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _ai_sort_search_space_async(search_space_id, user_id)
|
||||||
try:
|
)
|
||||||
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||||
|
|
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||||
)
|
)
|
||||||
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
|
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
|
||||||
"""Incremental AI sort for a single document after indexing."""
|
"""Incremental AI sort for a single document after indexing."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_ai_sort_document_async(search_space_id, user_id, document_id)
|
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,13 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector
|
||||||
from app.schemas.obsidian_plugin import NotePayload
|
from app.schemas.obsidian_plugin import NotePayload
|
||||||
from app.services.obsidian_plugin_indexer import upsert_note
|
from app.services.obsidian_plugin_indexer import upsert_note
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process one Obsidian non-markdown attachment asynchronously."""
|
"""Process one Obsidian non-markdown attachment asynchronously."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(
|
||||||
asyncio.set_event_loop(loop)
|
lambda: _index_obsidian_attachment(
|
||||||
try:
|
|
||||||
loop.run_until_complete(
|
|
||||||
_index_obsidian_attachment(
|
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
payload_data=payload_data,
|
payload_data=payload_data,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _index_obsidian_attachment(
|
async def _index_obsidian_attachment(
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -12,11 +13,12 @@ from app.celery_app import celery_app
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
from app.db import Podcast, PodcastStatus
|
from app.db import Podcast, PodcastStatus
|
||||||
from app.services.billable_calls import (
|
from app.services.billable_calls import (
|
||||||
|
BillingSettlementError,
|
||||||
QuotaInsufficientError,
|
QuotaInsufficientError,
|
||||||
_resolve_agent_billing_for_search_space,
|
_resolve_agent_billing_for_search_space,
|
||||||
billable_call,
|
billable_call,
|
||||||
)
|
)
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -34,6 +36,13 @@ if sys.platform.startswith("win"):
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _celery_billable_session():
|
||||||
|
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||||
|
async with get_celery_session_maker()() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="generate_content_podcast", bind=True)
|
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||||
def generate_content_podcast_task(
|
def generate_content_podcast_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -46,27 +55,22 @@ def generate_content_podcast_task(
|
||||||
Celery task to generate podcast from source content.
|
Celery task to generate podcast from source content.
|
||||||
Updates existing podcast record created by the tool.
|
Updates existing podcast record created by the tool.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = loop.run_until_complete(
|
return run_async_celery_task(
|
||||||
_generate_content_podcast(
|
lambda: _generate_content_podcast(
|
||||||
podcast_id,
|
podcast_id,
|
||||||
source_content,
|
source_content,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
user_prompt,
|
user_prompt,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating content podcast: {e!s}")
|
logger.error(f"Error generating content podcast: {e!s}")
|
||||||
loop.run_until_complete(_mark_podcast_failed(podcast_id))
|
try:
|
||||||
|
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to mark podcast %s as failed", podcast_id)
|
||||||
return {"status": "failed", "podcast_id": podcast_id}
|
return {"status": "failed", "podcast_id": podcast_id}
|
||||||
finally:
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _mark_podcast_failed(podcast_id: int) -> None:
|
async def _mark_podcast_failed(podcast_id: int) -> None:
|
||||||
|
|
@ -148,11 +152,12 @@ async def _generate_content_podcast(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||||
usage_type="podcast_generation",
|
usage_type="podcast_generation",
|
||||||
thread_id=podcast.thread_id,
|
|
||||||
call_details={
|
call_details={
|
||||||
"podcast_id": podcast.id,
|
"podcast_id": podcast.id,
|
||||||
"title": podcast.title,
|
"title": podcast.title,
|
||||||
|
"thread_id": podcast.thread_id,
|
||||||
},
|
},
|
||||||
|
billable_session_factory=_celery_billable_session,
|
||||||
):
|
):
|
||||||
graph_result = await podcaster_graph.ainvoke(
|
graph_result = await podcaster_graph.ainvoke(
|
||||||
initial_state, config=graph_config
|
initial_state, config=graph_config
|
||||||
|
|
@ -173,6 +178,18 @@ async def _generate_content_podcast(
|
||||||
"podcast_id": podcast.id,
|
"podcast_id": podcast.id,
|
||||||
"reason": "premium_quota_exhausted",
|
"reason": "premium_quota_exhausted",
|
||||||
}
|
}
|
||||||
|
except BillingSettlementError:
|
||||||
|
logger.exception(
|
||||||
|
"Podcast %s: premium billing settlement failed",
|
||||||
|
podcast.id,
|
||||||
|
)
|
||||||
|
podcast.status = PodcastStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"podcast_id": podcast.id,
|
||||||
|
"reason": "billing_settlement_failed",
|
||||||
|
}
|
||||||
|
|
||||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||||
file_path = graph_result.get("final_podcast_file_path", "")
|
file_path = graph_result.get("final_podcast_file_path", "")
|
||||||
|
|
@ -194,7 +211,14 @@ async def _generate_content_podcast(
|
||||||
podcast.podcast_transcript = serializable_transcript
|
podcast.podcast_transcript = serializable_transcript
|
||||||
podcast.file_location = file_path
|
podcast.file_location = file_path
|
||||||
podcast.status = PodcastStatus.READY
|
podcast.status = PodcastStatus.READY
|
||||||
|
logger.info(
|
||||||
|
"Podcast %s: committing READY transcript_entries=%d file=%s",
|
||||||
|
podcast.id,
|
||||||
|
len(serializable_transcript),
|
||||||
|
file_path,
|
||||||
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
logger.info("Podcast %s: READY commit complete", podcast.id)
|
||||||
|
|
||||||
logger.info(f"Successfully generated podcast: {podcast.id}")
|
logger.info(f"Successfully generated podcast: {podcast.id}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
from app.utils.indexing_locks import is_connector_indexing_locked
|
from app.utils.indexing_locks import is_connector_indexing_locked
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -20,15 +20,7 @@ def check_periodic_schedules_task():
|
||||||
This task runs every minute and triggers indexing for any connector
|
This task runs every minute and triggers indexing for any connector
|
||||||
whose next_scheduled_at time has passed.
|
whose next_scheduled_at time has passed.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
return run_async_celery_task(_check_and_trigger_schedules)
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_check_and_trigger_schedules())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_and_trigger_schedules():
|
async def _check_and_trigger_schedules():
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from sqlalchemy.future import select
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import Document, DocumentStatus, Notification
|
from app.db import Document, DocumentStatus, Notification
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
|
||||||
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
|
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
|
||||||
Also marks associated pending/processing documents as failed.
|
Also marks associated pending/processing documents as failed.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
async def _both() -> None:
|
||||||
asyncio.set_event_loop(loop)
|
await _cleanup_stale_notifications()
|
||||||
|
await _cleanup_stale_document_processing_notifications()
|
||||||
|
|
||||||
try:
|
return run_async_celery_task(_both)
|
||||||
loop.run_until_complete(_cleanup_stale_notifications())
|
|
||||||
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _cleanup_stale_notifications():
|
async def _cleanup_stale_notifications():
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
|
@ -18,7 +17,7 @@ from app.db import (
|
||||||
PremiumTokenPurchaseStatus,
|
PremiumTokenPurchaseStatus,
|
||||||
)
|
)
|
||||||
from app.routes import stripe_routes
|
from app.routes import stripe_routes
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
|
||||||
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
|
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
|
||||||
def reconcile_pending_stripe_page_purchases_task():
|
def reconcile_pending_stripe_page_purchases_task():
|
||||||
"""Recover paid purchases that were left pending due to missed webhook handling."""
|
"""Recover paid purchases that were left pending due to missed webhook handling."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(_reconcile_pending_page_purchases)
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_reconcile_pending_page_purchases())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _reconcile_pending_page_purchases() -> None:
|
async def _reconcile_pending_page_purchases() -> None:
|
||||||
|
|
@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
|
||||||
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
|
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
|
||||||
def reconcile_pending_stripe_token_purchases_task():
|
def reconcile_pending_stripe_token_purchases_task():
|
||||||
"""Recover paid token purchases that were left pending due to missed webhook handling."""
|
"""Recover paid token purchases that were left pending due to missed webhook handling."""
|
||||||
loop = asyncio.new_event_loop()
|
return run_async_celery_task(_reconcile_pending_token_purchases)
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_reconcile_pending_token_purchases())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _reconcile_pending_token_purchases() -> None:
|
async def _reconcile_pending_token_purchases() -> None:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -12,11 +13,12 @@ from app.celery_app import celery_app
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
from app.db import VideoPresentation, VideoPresentationStatus
|
from app.db import VideoPresentation, VideoPresentationStatus
|
||||||
from app.services.billable_calls import (
|
from app.services.billable_calls import (
|
||||||
|
BillingSettlementError,
|
||||||
QuotaInsufficientError,
|
QuotaInsufficientError,
|
||||||
_resolve_agent_billing_for_search_space,
|
_resolve_agent_billing_for_search_space,
|
||||||
billable_call,
|
billable_call,
|
||||||
)
|
)
|
||||||
from app.tasks.celery_tasks import get_celery_session_maker
|
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -29,6 +31,13 @@ if sys.platform.startswith("win"):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _celery_billable_session():
|
||||||
|
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||||
|
async with get_celery_session_maker()() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="generate_video_presentation", bind=True)
|
@celery_app.task(name="generate_video_presentation", bind=True)
|
||||||
def generate_video_presentation_task(
|
def generate_video_presentation_task(
|
||||||
self,
|
self,
|
||||||
|
|
@ -41,27 +50,30 @@ def generate_video_presentation_task(
|
||||||
Celery task to generate video presentation from source content.
|
Celery task to generate video presentation from source content.
|
||||||
Updates existing video presentation record created by the tool.
|
Updates existing video presentation record created by the tool.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = loop.run_until_complete(
|
return run_async_celery_task(
|
||||||
_generate_video_presentation(
|
lambda: _generate_video_presentation(
|
||||||
video_presentation_id,
|
video_presentation_id,
|
||||||
source_content,
|
source_content,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
user_prompt,
|
user_prompt,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating video presentation: {e!s}")
|
logger.error(f"Error generating video presentation: {e!s}")
|
||||||
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
|
# Mark FAILED in a fresh loop — the previous loop is closed.
|
||||||
|
# Swallow secondary failures; the row will simply stay in
|
||||||
|
# GENERATING and be flushed by the periodic stale cleanup.
|
||||||
|
try:
|
||||||
|
run_async_celery_task(
|
||||||
|
lambda: _mark_video_presentation_failed(video_presentation_id)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to mark video presentation %s as failed",
|
||||||
|
video_presentation_id,
|
||||||
|
)
|
||||||
return {"status": "failed", "video_presentation_id": video_presentation_id}
|
return {"status": "failed", "video_presentation_id": video_presentation_id}
|
||||||
finally:
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
|
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
|
||||||
|
|
@ -150,11 +162,12 @@ async def _generate_video_presentation(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||||
usage_type="video_presentation_generation",
|
usage_type="video_presentation_generation",
|
||||||
thread_id=video_pres.thread_id,
|
|
||||||
call_details={
|
call_details={
|
||||||
"video_presentation_id": video_pres.id,
|
"video_presentation_id": video_pres.id,
|
||||||
"title": video_pres.title,
|
"title": video_pres.title,
|
||||||
|
"thread_id": video_pres.thread_id,
|
||||||
},
|
},
|
||||||
|
billable_session_factory=_celery_billable_session,
|
||||||
):
|
):
|
||||||
graph_result = await video_presentation_graph.ainvoke(
|
graph_result = await video_presentation_graph.ainvoke(
|
||||||
initial_state, config=graph_config
|
initial_state, config=graph_config
|
||||||
|
|
@ -175,6 +188,18 @@ async def _generate_video_presentation(
|
||||||
"video_presentation_id": video_pres.id,
|
"video_presentation_id": video_pres.id,
|
||||||
"reason": "premium_quota_exhausted",
|
"reason": "premium_quota_exhausted",
|
||||||
}
|
}
|
||||||
|
except BillingSettlementError:
|
||||||
|
logger.exception(
|
||||||
|
"VideoPresentation %s: premium billing settlement failed",
|
||||||
|
video_pres.id,
|
||||||
|
)
|
||||||
|
video_pres.status = VideoPresentationStatus.FAILED
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"video_presentation_id": video_pres.id,
|
||||||
|
"reason": "billing_settlement_failed",
|
||||||
|
}
|
||||||
|
|
||||||
# Serialize slides (parsed content + audio info merged)
|
# Serialize slides (parsed content + audio info merged)
|
||||||
slides_raw = graph_result.get("slides", [])
|
slides_raw = graph_result.get("slides", [])
|
||||||
|
|
@ -205,7 +230,14 @@ async def _generate_video_presentation(
|
||||||
video_pres.slides = serializable_slides
|
video_pres.slides = serializable_slides
|
||||||
video_pres.scene_codes = serializable_scene_codes
|
video_pres.scene_codes = serializable_scene_codes
|
||||||
video_pres.status = VideoPresentationStatus.READY
|
video_pres.status = VideoPresentationStatus.READY
|
||||||
|
logger.info(
|
||||||
|
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
|
||||||
|
video_pres.id,
|
||||||
|
len(serializable_slides),
|
||||||
|
len(serializable_scene_codes),
|
||||||
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
|
||||||
|
|
||||||
logger.info(f"Successfully generated video presentation: {video_pres.id}")
|
logger.info(f"Successfully generated video presentation: {video_pres.id}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
|
|
@ -96,6 +97,47 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||||
|
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
|
||||||
|
|
||||||
|
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(candidate, dict):
|
||||||
|
value = candidate.get("value", candidate)
|
||||||
|
return value if isinstance(value, dict) else None
|
||||||
|
value = getattr(candidate, "value", None)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
if isinstance(candidate, (list, tuple)):
|
||||||
|
for item in candidate:
|
||||||
|
extracted = _extract_interrupt_value(item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
|
for task in getattr(state, "tasks", ()) or ():
|
||||||
|
try:
|
||||||
|
interrupts = getattr(task, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
interrupts = ()
|
||||||
|
if not interrupts:
|
||||||
|
extracted = _extract_interrupt_value(task)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
continue
|
||||||
|
for interrupt_item in interrupts:
|
||||||
|
extracted = _extract_interrupt_value(interrupt_item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
try:
|
||||||
|
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
state_interrupts = ()
|
||||||
|
extracted = _extract_interrupt_value(state_interrupts)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||||
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
||||||
|
|
||||||
|
|
@ -518,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
||||||
|
"""Wait for a discarded speculative agent build to release shared state.
|
||||||
|
|
||||||
|
Used by the parallel preflight + agent-build path. The speculative build
|
||||||
|
closes over the request-scoped ``AsyncSession`` (for the brief connector
|
||||||
|
discovery / tool-factory window before its CPU work moves into a worker
|
||||||
|
thread). If preflight reports a 429 we want to fall back to the original
|
||||||
|
repin → reload → rebuild path, but we MUST NOT touch ``session`` again
|
||||||
|
until any in-flight session work owned by the speculative build has
|
||||||
|
fully settled — :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
|
||||||
|
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
|
||||||
|
earlier in this PR (see ``connector_service`` parallel-gather revert).
|
||||||
|
|
||||||
|
We simply ``await`` the task and swallow any exception: in this path the
|
||||||
|
build's outcome is irrelevant — success populates the agent cache (a free
|
||||||
|
side effect), failure is discarded. The wasted CPU is acceptable since
|
||||||
|
429 fallbacks are rare and the original sequential code also paid the
|
||||||
|
full build cost on the same path.
|
||||||
|
"""
|
||||||
|
with contextlib.suppress(BaseException):
|
||||||
|
await task
|
||||||
|
|
||||||
|
|
||||||
def _classify_stream_exception(
|
def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
|
|
@ -655,6 +720,7 @@ async def _stream_agent_events(
|
||||||
fallback_commit_created_by_id: str | None = None,
|
fallback_commit_created_by_id: str | None = None,
|
||||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
fallback_commit_thread_id: int | None = None,
|
fallback_commit_thread_id: int | None = None,
|
||||||
|
runtime_context: Any = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Shared async generator that streams and formats astream_events from the agent.
|
"""Shared async generator that streams and formats astream_events from the agent.
|
||||||
|
|
||||||
|
|
@ -760,7 +826,18 @@ async def _stream_agent_events(
|
||||||
return event
|
return event
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async for event in agent.astream_events(input_data, config=config, version="v2"):
|
# Per-invocation runtime context (Phase 1.5). When supplied,
|
||||||
|
# ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids``
|
||||||
|
# from ``runtime.context`` instead of its constructor closure — the
|
||||||
|
# prerequisite that lets the compiled-agent cache (Phase 1) reuse a
|
||||||
|
# single graph across turns. Astream_events_kwargs stays empty when
|
||||||
|
# callers leave ``runtime_context`` as ``None`` to preserve the
|
||||||
|
# legacy code path bit-for-bit.
|
||||||
|
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
|
||||||
|
if runtime_context is not None:
|
||||||
|
astream_kwargs["context"] = runtime_context
|
||||||
|
|
||||||
|
async for event in agent.astream_events(input_data, **astream_kwargs):
|
||||||
event_type = event.get("event", "")
|
event_type = event.get("event", "")
|
||||||
|
|
||||||
if event_type == "on_chat_model_stream":
|
if event_type == "on_chat_model_stream":
|
||||||
|
|
@ -1506,10 +1583,10 @@ async def _stream_agent_events(
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
else "Podcast"
|
else "Podcast"
|
||||||
)
|
)
|
||||||
if podcast_status == "processing":
|
if podcast_status in ("pending", "generating", "processing"):
|
||||||
completed_items = [
|
completed_items = [
|
||||||
f"Title: {podcast_title}",
|
f"Title: {podcast_title}",
|
||||||
"Audio generation started",
|
"Podcast generation started",
|
||||||
"Processing in background...",
|
"Processing in background...",
|
||||||
]
|
]
|
||||||
elif podcast_status == "already_generating":
|
elif podcast_status == "already_generating":
|
||||||
|
|
@ -1518,7 +1595,7 @@ async def _stream_agent_events(
|
||||||
"Podcast already in progress",
|
"Podcast already in progress",
|
||||||
"Please wait for it to complete",
|
"Please wait for it to complete",
|
||||||
]
|
]
|
||||||
elif podcast_status == "error":
|
elif podcast_status in ("failed", "error"):
|
||||||
error_msg = (
|
error_msg = (
|
||||||
tool_output.get("error", "Unknown error")
|
tool_output.get("error", "Unknown error")
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
|
|
@ -1528,6 +1605,11 @@ async def _stream_agent_events(
|
||||||
f"Title: {podcast_title}",
|
f"Title: {podcast_title}",
|
||||||
f"Error: {error_msg[:50]}",
|
f"Error: {error_msg[:50]}",
|
||||||
]
|
]
|
||||||
|
elif podcast_status in ("ready", "success"):
|
||||||
|
completed_items = [
|
||||||
|
f"Title: {podcast_title}",
|
||||||
|
"Podcast ready",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
completed_items = last_active_step_items
|
completed_items = last_active_step_items
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
|
|
@ -1710,20 +1792,28 @@ async def _stream_agent_events(
|
||||||
if isinstance(tool_output, dict)
|
if isinstance(tool_output, dict)
|
||||||
else {"result": tool_output},
|
else {"result": tool_output},
|
||||||
)
|
)
|
||||||
if (
|
if isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||||
isinstance(tool_output, dict)
|
"pending",
|
||||||
and tool_output.get("status") == "success"
|
"generating",
|
||||||
|
"processing",
|
||||||
|
):
|
||||||
|
yield streaming_service.format_terminal_info(
|
||||||
|
f"Podcast queued: {tool_output.get('title', 'Podcast')}",
|
||||||
|
"success",
|
||||||
|
)
|
||||||
|
elif isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||||
|
"ready",
|
||||||
|
"success",
|
||||||
):
|
):
|
||||||
yield streaming_service.format_terminal_info(
|
yield streaming_service.format_terminal_info(
|
||||||
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
||||||
"success",
|
"success",
|
||||||
)
|
)
|
||||||
else:
|
elif isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||||
error_msg = (
|
"failed",
|
||||||
tool_output.get("error", "Unknown error")
|
"error",
|
||||||
if isinstance(tool_output, dict)
|
):
|
||||||
else "Unknown error"
|
error_msg = tool_output.get("error", "Unknown error")
|
||||||
)
|
|
||||||
yield streaming_service.format_terminal_info(
|
yield streaming_service.format_terminal_info(
|
||||||
f"Podcast generation failed: {error_msg}",
|
f"Podcast generation failed: {error_msg}",
|
||||||
"error",
|
"error",
|
||||||
|
|
@ -2165,10 +2255,10 @@ async def _stream_agent_events(
|
||||||
result.agent_called_update_memory = called_update_memory
|
result.agent_called_update_memory = called_update_memory
|
||||||
_log_file_contract("turn_outcome", result)
|
_log_file_contract("turn_outcome", result)
|
||||||
|
|
||||||
is_interrupted = state.tasks and any(task.interrupts for task in state.tasks)
|
interrupt_value = _first_interrupt_value(state)
|
||||||
if is_interrupted:
|
if interrupt_value is not None:
|
||||||
result.is_interrupted = True
|
result.is_interrupted = True
|
||||||
result.interrupt_value = state.tasks[0].interrupts[0].value
|
result.interrupt_value = interrupt_value
|
||||||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2292,6 +2382,11 @@ async def stream_new_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
|
# Image-bearing turns force the Auto-pin resolver to filter the
|
||||||
|
# candidate pool to vision-capable cfgs (and force-repin a
|
||||||
|
# text-only existing pin). For explicit selections this flag is
|
||||||
|
# a no-op — the resolver returns the user's chosen id unchanged.
|
||||||
|
_requires_image_input = bool(user_image_data_urls)
|
||||||
try:
|
try:
|
||||||
llm_config_id = (
|
llm_config_id = (
|
||||||
await resolve_or_get_pinned_llm_config_id(
|
await resolve_or_get_pinned_llm_config_id(
|
||||||
|
|
@ -2300,13 +2395,29 @@ async def stream_new_chat(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
selected_llm_config_id=llm_config_id,
|
selected_llm_config_id=llm_config_id,
|
||||||
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
|
# Auto-pin's "no vision-capable cfg" path raises a ValueError
|
||||||
|
# whose message we map to the friendly image-input SSE error
|
||||||
|
# so the user sees the same message regardless of whether
|
||||||
|
# the gate fired in Auto-mode or in the agent_config check
|
||||||
|
# below.
|
||||||
|
error_code = (
|
||||||
|
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||||
|
if _requires_image_input and "vision-capable" in str(pin_error)
|
||||||
|
else "SERVER_ERROR"
|
||||||
|
)
|
||||||
|
error_kind = (
|
||||||
|
"user_error"
|
||||||
|
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||||
|
else "server_error"
|
||||||
|
)
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=str(pin_error),
|
message=str(pin_error),
|
||||||
error_kind="server_error",
|
error_kind=error_kind,
|
||||||
error_code="SERVER_ERROR",
|
error_code=error_code,
|
||||||
)
|
)
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
@ -2326,6 +2437,50 @@ async def stream_new_chat(
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Capability safety net: a turn carrying user-uploaded images
|
||||||
|
# cannot be routed to a chat config that LiteLLM's authoritative
|
||||||
|
# model map *explicitly* marks as text-only (``supports_vision``
|
||||||
|
# set to False). The check is intentionally narrow — it only
|
||||||
|
# fires when LiteLLM is *certain* the model can't accept image
|
||||||
|
# input. Unknown / unmapped / vision-capable models pass
|
||||||
|
# through. Without this guard a known-text-only model would 404
|
||||||
|
# at the provider with ``"No endpoints found that support image
|
||||||
|
# input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
|
||||||
|
# failing here lets us return a friendly message that tells the
|
||||||
|
# user what to change.
|
||||||
|
if user_image_data_urls and agent_config is not None:
|
||||||
|
from app.services.provider_capabilities import (
|
||||||
|
is_known_text_only_chat_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_litellm_params = agent_config.litellm_params or {}
|
||||||
|
agent_base_model = (
|
||||||
|
agent_litellm_params.get("base_model")
|
||||||
|
if isinstance(agent_litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if is_known_text_only_chat_model(
|
||||||
|
provider=agent_config.provider,
|
||||||
|
model_name=agent_config.model_name,
|
||||||
|
base_model=agent_base_model,
|
||||||
|
custom_provider=agent_config.custom_provider,
|
||||||
|
):
|
||||||
|
model_label = (
|
||||||
|
agent_config.config_name or agent_config.model_name or "model"
|
||||||
|
)
|
||||||
|
yield _emit_stream_error(
|
||||||
|
message=(
|
||||||
|
f"The selected model ({model_label}) does not support "
|
||||||
|
"image input. Switch to a vision-capable model "
|
||||||
|
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
|
||||||
|
"attachment and try again."
|
||||||
|
),
|
||||||
|
error_kind="user_error",
|
||||||
|
error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||||
|
)
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
# Premium quota reservation for pinned premium model only.
|
# Premium quota reservation for pinned premium model only.
|
||||||
_needs_premium_quota = (
|
_needs_premium_quota = (
|
||||||
agent_config is not None and user_id and agent_config.is_premium
|
agent_config is not None and user_id and agent_config.is_premium
|
||||||
|
|
@ -2366,6 +2521,7 @@ async def stream_new_chat(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
selected_llm_config_id=0,
|
selected_llm_config_id=0,
|
||||||
force_repin_free=True,
|
force_repin_free=True,
|
||||||
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
|
|
@ -2440,23 +2596,102 @@ async def stream_new_chat(
|
||||||
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
|
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
|
||||||
# title-generation LLM calls fan out and each independently hit the
|
# title-generation LLM calls fan out and each independently hit the
|
||||||
# same upstream rate limit.
|
# same upstream rate limit.
|
||||||
if (
|
#
|
||||||
|
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
|
||||||
|
# and is independent of the agent build (CPU-bound, ~5-7s). They used
|
||||||
|
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
|
||||||
|
# We now kick off preflight as a background task FIRST, then run the
|
||||||
|
# synchronous setup work and the agent build in parallel. In the
|
||||||
|
# success path (the common case) total wall time drops to roughly
|
||||||
|
# ``max(preflight, build)`` — the preflight finishes during the
|
||||||
|
# agent compile and we just consume its result. In the rare 429
|
||||||
|
# path the speculative build is awaited to completion (so its
|
||||||
|
# session usage is fully released) via
|
||||||
|
# :func:`_settle_speculative_agent_build`, then discarded, and
|
||||||
|
# we fall back to the original repin-and-rebuild flow.
|
||||||
|
preflight_needed = (
|
||||||
requested_llm_config_id == 0
|
requested_llm_config_id == 0
|
||||||
and llm_config_id < 0
|
and llm_config_id < 0
|
||||||
and not is_recently_healthy(llm_config_id)
|
and not is_recently_healthy(llm_config_id)
|
||||||
):
|
)
|
||||||
|
preflight_task: asyncio.Task[None] | None = None
|
||||||
|
_t_preflight = 0.0
|
||||||
|
if preflight_needed:
|
||||||
_t_preflight = time.perf_counter()
|
_t_preflight = time.perf_counter()
|
||||||
|
preflight_task = asyncio.create_task(
|
||||||
|
_preflight_llm(llm),
|
||||||
|
name=f"auto_pin_preflight:{llm_config_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create connector service
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
|
firecrawl_api_key = None
|
||||||
|
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||||
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||||
|
)
|
||||||
|
if webcrawler_connector and webcrawler_connector.config:
|
||||||
|
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
checkpointer = await get_checkpointer()
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||||
|
)
|
||||||
|
|
||||||
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
# Speculative agent build — runs in parallel with the preflight
|
||||||
|
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
|
||||||
|
# if preflight reports 429 we will discard this future and rebuild
|
||||||
|
# against the freshly pinned config below.
|
||||||
|
agent_build_task = asyncio.create_task(
|
||||||
|
create_surfsense_deep_agent(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
agent_config=agent_config,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
),
|
||||||
|
name="agent_build:stream_new_chat",
|
||||||
|
)
|
||||||
|
|
||||||
|
agent: Any = None
|
||||||
|
if preflight_task is not None:
|
||||||
try:
|
try:
|
||||||
await _preflight_llm(llm)
|
await preflight_task
|
||||||
mark_healthy(llm_config_id)
|
mark_healthy(llm_config_id)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs",
|
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
time.perf_counter() - _t_preflight,
|
time.perf_counter() - _t_preflight,
|
||||||
)
|
)
|
||||||
except Exception as preflight_exc:
|
except Exception as preflight_exc:
|
||||||
|
# Both branches below need the session: the non-429 path
|
||||||
|
# may unwind via cleanup that uses ``session``, and the
|
||||||
|
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
|
||||||
|
# against it. Wait for the speculative build to release its
|
||||||
|
# session usage before we proceed.
|
||||||
|
await _settle_speculative_agent_build(agent_build_task)
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
if not _is_provider_rate_limited(preflight_exc):
|
||||||
raise
|
raise
|
||||||
|
# 429: speculative agent is discarded; run the original
|
||||||
|
# repin → reload → rebuild path against the freshly
|
||||||
|
# pinned config.
|
||||||
previous_config_id = llm_config_id
|
previous_config_id = llm_config_id
|
||||||
mark_runtime_cooldown(
|
mark_runtime_cooldown(
|
||||||
previous_config_id, reason="preflight_rate_limited"
|
previous_config_id, reason="preflight_rate_limited"
|
||||||
|
|
@ -2470,6 +2705,7 @@ async def stream_new_chat(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
selected_llm_config_id=0,
|
selected_llm_config_id=0,
|
||||||
exclude_config_ids={previous_config_id},
|
exclude_config_ids={previous_config_id},
|
||||||
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
|
|
@ -2518,31 +2754,8 @@ async def stream_new_chat(
|
||||||
"fallback_config_id": llm_config_id,
|
"fallback_config_id": llm_config_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Rebuild against the new llm/agent_config. Sequential
|
||||||
# Create connector service
|
# here because we no longer have anything to overlap with.
|
||||||
_t0 = time.perf_counter()
|
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
|
||||||
|
|
||||||
firecrawl_api_key = None
|
|
||||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
|
||||||
)
|
|
||||||
if webcrawler_connector and webcrawler_connector.config:
|
|
||||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
checkpointer = await get_checkpointer()
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
|
||||||
)
|
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -2558,6 +2771,11 @@ async def stream_new_chat(
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
filesystem_selection=filesystem_selection,
|
filesystem_selection=filesystem_selection,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if agent is None:
|
||||||
|
# Either no preflight was needed, or preflight succeeded —
|
||||||
|
# in both cases the speculative build is the agent we want.
|
||||||
|
agent = await agent_build_task
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
@ -2804,6 +3022,7 @@ async def stream_new_chat(
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
|
|
||||||
from app.services.llm_router_service import LLMRouterService
|
from app.services.llm_router_service import LLMRouterService
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
from app.services.token_tracking_service import _turn_accumulator
|
from app.services.token_tracking_service import _turn_accumulator
|
||||||
|
|
||||||
_turn_accumulator.set(None)
|
_turn_accumulator.set(None)
|
||||||
|
|
@ -2824,11 +3043,32 @@ async def stream_new_chat(
|
||||||
model="auto", messages=messages
|
model="auto", messages=messages
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Apply the same ``api_base`` cascade chat / vision /
|
||||||
|
# image-gen call sites use so we never inherit
|
||||||
|
# ``litellm.api_base`` (commonly set by
|
||||||
|
# ``AZURE_OPENAI_ENDPOINT``) when the chat config
|
||||||
|
# itself ships an empty ``api_base``. Without this
|
||||||
|
# the title-gen on an OpenRouter chat config would
|
||||||
|
# 404 against the inherited Azure endpoint — see
|
||||||
|
# ``provider_api_base`` docstring for the same
|
||||||
|
# bug repro on the image-gen / vision paths.
|
||||||
|
raw_model = getattr(llm, "model", "") or ""
|
||||||
|
provider_prefix = (
|
||||||
|
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||||
|
)
|
||||||
|
provider_value = (
|
||||||
|
agent_config.provider if agent_config is not None else None
|
||||||
|
)
|
||||||
|
title_api_base = resolve_api_base(
|
||||||
|
provider=provider_value,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=getattr(llm, "api_base", None),
|
||||||
|
)
|
||||||
response = await acompletion(
|
response = await acompletion(
|
||||||
model=llm.model,
|
model=raw_model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_key=getattr(llm, "api_key", None),
|
api_key=getattr(llm, "api_key", None),
|
||||||
api_base=getattr(llm, "api_base", None),
|
api_base=title_api_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
usage_info = None
|
usage_info = None
|
||||||
|
|
@ -2862,6 +3102,18 @@ async def stream_new_chat(
|
||||||
|
|
||||||
title_emitted = False
|
title_emitted = False
|
||||||
|
|
||||||
|
# Build the per-invocation runtime context (Phase 1.5).
|
||||||
|
# ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
|
||||||
|
# via ``runtime.context.mentioned_document_ids`` instead of its
|
||||||
|
# ``__init__`` closure — that way the same compiled-agent instance
|
||||||
|
# can serve multiple turns with different mention lists.
|
||||||
|
runtime_context = SurfSenseContextSchema(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||||
|
request_id=request_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
)
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
runtime_rate_limit_recovered = False
|
runtime_rate_limit_recovered = False
|
||||||
|
|
@ -2885,6 +3137,7 @@ async def stream_new_chat(
|
||||||
else FilesystemMode.CLOUD
|
else FilesystemMode.CLOUD
|
||||||
),
|
),
|
||||||
fallback_commit_thread_id=chat_id,
|
fallback_commit_thread_id=chat_id,
|
||||||
|
runtime_context=runtime_context,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -2953,6 +3206,7 @@ async def stream_new_chat(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
selected_llm_config_id=0,
|
selected_llm_config_id=0,
|
||||||
exclude_config_ids={previous_config_id},
|
exclude_config_ids={previous_config_id},
|
||||||
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
|
|
||||||
|
|
@ -3499,21 +3753,75 @@ async def stream_resume_chat(
|
||||||
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
|
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
|
||||||
# one cheap probe before the agent is rebuilt so a 429'd pin gets
|
# one cheap probe before the agent is rebuilt so a 429'd pin gets
|
||||||
# repinned without burning planner/classifier/title calls first.
|
# repinned without burning planner/classifier/title calls first.
|
||||||
if (
|
# See ``stream_new_chat`` for the full rationale on the speculative
|
||||||
|
# parallel build pattern below.
|
||||||
|
preflight_needed = (
|
||||||
requested_llm_config_id == 0
|
requested_llm_config_id == 0
|
||||||
and llm_config_id < 0
|
and llm_config_id < 0
|
||||||
and not is_recently_healthy(llm_config_id)
|
and not is_recently_healthy(llm_config_id)
|
||||||
):
|
)
|
||||||
|
preflight_task: asyncio.Task[None] | None = None
|
||||||
|
_t_preflight = 0.0
|
||||||
|
if preflight_needed:
|
||||||
_t_preflight = time.perf_counter()
|
_t_preflight = time.perf_counter()
|
||||||
|
preflight_task = asyncio.create_task(
|
||||||
|
_preflight_llm(llm),
|
||||||
|
name=f"auto_pin_preflight_resume:{llm_config_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
|
firecrawl_api_key = None
|
||||||
|
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||||
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||||
|
)
|
||||||
|
if webcrawler_connector and webcrawler_connector.config:
|
||||||
|
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] Connector service + firecrawl key in %.3fs",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
checkpointer = await get_checkpointer()
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||||
|
)
|
||||||
|
|
||||||
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
agent_build_task = asyncio.create_task(
|
||||||
|
create_surfsense_deep_agent(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
agent_config=agent_config,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
),
|
||||||
|
name="agent_build:stream_resume",
|
||||||
|
)
|
||||||
|
|
||||||
|
agent: Any = None
|
||||||
|
if preflight_task is not None:
|
||||||
try:
|
try:
|
||||||
await _preflight_llm(llm)
|
await preflight_task
|
||||||
mark_healthy(llm_config_id)
|
mark_healthy(llm_config_id)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs",
|
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
time.perf_counter() - _t_preflight,
|
time.perf_counter() - _t_preflight,
|
||||||
)
|
)
|
||||||
except Exception as preflight_exc:
|
except Exception as preflight_exc:
|
||||||
|
# Same session-safety rationale as ``stream_new_chat``.
|
||||||
|
await _settle_speculative_agent_build(agent_build_task)
|
||||||
if not _is_provider_rate_limited(preflight_exc):
|
if not _is_provider_rate_limited(preflight_exc):
|
||||||
raise
|
raise
|
||||||
previous_config_id = llm_config_id
|
previous_config_id = llm_config_id
|
||||||
|
|
@ -3573,30 +3881,6 @@ async def stream_resume_chat(
|
||||||
"fallback_config_id": llm_config_id,
|
"fallback_config_id": llm_config_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
|
||||||
|
|
||||||
firecrawl_api_key = None
|
|
||||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
|
||||||
)
|
|
||||||
if webcrawler_connector and webcrawler_connector.config:
|
|
||||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_resume] Connector service + firecrawl key in %.3fs",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
)
|
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
checkpointer = await get_checkpointer()
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
|
||||||
)
|
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -3610,6 +3894,9 @@ async def stream_resume_chat(
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
filesystem_selection=filesystem_selection,
|
filesystem_selection=filesystem_selection,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if agent is None:
|
||||||
|
agent = await agent_build_task
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
@ -3650,6 +3937,16 @@ async def stream_resume_chat(
|
||||||
)
|
)
|
||||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
|
# Resume path doesn't carry new ``mentioned_document_ids`` —
|
||||||
|
# those are seeded in the original turn. We still pass a
|
||||||
|
# context so future middleware extensions (Phase 2) can rely on
|
||||||
|
# ``runtime.context`` always being populated.
|
||||||
|
runtime_context = SurfSenseContextSchema(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
request_id=request_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
)
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
runtime_rate_limit_recovered = False
|
runtime_rate_limit_recovered = False
|
||||||
|
|
@ -3670,6 +3967,7 @@ async def stream_resume_chat(
|
||||||
else FilesystemMode.CLOUD
|
else FilesystemMode.CLOUD
|
||||||
),
|
),
|
||||||
fallback_commit_thread_id=chat_id,
|
fallback_commit_thread_id=chat_id,
|
||||||
|
runtime_context=runtime_context,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
check_duplicate_document_by_hash,
|
check_duplicate_document_by_hash,
|
||||||
|
|
@ -44,6 +42,10 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _format_calendar_event_to_markdown(event: dict) -> str:
|
||||||
|
return GoogleCalendarConnector.format_event_to_markdown(None, event)
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_doc(
|
def _build_connector_doc(
|
||||||
event: dict,
|
event: dict,
|
||||||
event_markdown: str,
|
event_markdown: str,
|
||||||
|
|
@ -150,7 +152,14 @@ async def index_google_calendar_events(
|
||||||
)
|
)
|
||||||
return 0, 0, f"Connector with ID {connector_id} not found"
|
return 0, 0, f"Connector with ID {connector_id} not found"
|
||||||
|
|
||||||
# ── Credential building ───────────────────────────────────────
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
|
calendar_client = None
|
||||||
|
composio_service = None
|
||||||
|
connected_account_id = None
|
||||||
|
|
||||||
|
# ── Credential/client building ────────────────────────────────
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
|
|
@ -161,7 +170,7 @@ async def index_google_calendar_events(
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "MissingComposioAccount"},
|
||||||
)
|
)
|
||||||
return 0, 0, "Composio connected_account_id not found"
|
return 0, 0, "Composio connected_account_id not found"
|
||||||
credentials = build_composio_credentials(connected_account_id)
|
composio_service = ComposioService()
|
||||||
else:
|
else:
|
||||||
config_data = connector.config
|
config_data = connector.config
|
||||||
|
|
||||||
|
|
@ -229,6 +238,7 @@ async def index_google_calendar_events(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not is_composio_connector:
|
||||||
calendar_client = GoogleCalendarConnector(
|
calendar_client = GoogleCalendarConnector(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
session=session,
|
session=session,
|
||||||
|
|
@ -300,6 +310,23 @@ async def index_google_calendar_events(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if is_composio_connector:
|
||||||
|
start_dt = parse_date_flexible(start_date_str).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
end_dt = parse_date_flexible(end_date_str).replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=0
|
||||||
|
)
|
||||||
|
events, error = await composio_service.get_calendar_events(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
time_min=start_dt.isoformat(),
|
||||||
|
time_max=end_dt.isoformat(),
|
||||||
|
max_results=250,
|
||||||
|
)
|
||||||
|
if not events and not error:
|
||||||
|
error = "No events found in the specified date range."
|
||||||
|
else:
|
||||||
events, error = await calendar_client.get_all_primary_calendar_events(
|
events, error = await calendar_client.get_all_primary_calendar_events(
|
||||||
start_date=start_date_str, end_date=end_date_str
|
start_date=start_date_str, end_date=end_date_str
|
||||||
)
|
)
|
||||||
|
|
@ -381,7 +408,7 @@ async def index_google_calendar_events(
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_markdown = calendar_client.format_event_to_markdown(event)
|
event_markdown = _format_calendar_event_to_markdown(event)
|
||||||
if not event_markdown.strip():
|
if not event_markdown.strip():
|
||||||
logger.warning(f"Skipping event with no content: {event_summary}")
|
logger.warning(f"Skipping event with no content: {event_summary}")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import String, cast, select
|
from sqlalchemy import String, cast, select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
@ -37,6 +39,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.page_limit_service import PageLimitService
|
from app.services.page_limit_service import PageLimitService
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
|
@ -45,10 +48,7 @@ from app.tasks.connector_indexers.base import (
|
||||||
get_connector_by_id,
|
get_connector_by_id,
|
||||||
update_connector_last_indexed,
|
update_connector_last_indexed,
|
||||||
)
|
)
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
|
ACCEPTED_DRIVE_CONNECTOR_TYPES = {
|
||||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||||
|
|
@ -61,6 +61,209 @@ HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ComposioDriveClient:
|
||||||
|
"""Google Drive client facade backed by Composio tool execution.
|
||||||
|
|
||||||
|
Composio-managed OAuth connections can execute tools without exposing raw
|
||||||
|
OAuth tokens through connected account state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
connector_id: int,
|
||||||
|
connected_account_id: str,
|
||||||
|
entity_id: str,
|
||||||
|
):
|
||||||
|
self.session = session
|
||||||
|
self.connector_id = connector_id
|
||||||
|
self.connected_account_id = connected_account_id
|
||||||
|
self.entity_id = entity_id
|
||||||
|
self.composio = ComposioService()
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
self,
|
||||||
|
query: str = "",
|
||||||
|
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
|
||||||
|
page_size: int = 100,
|
||||||
|
page_token: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"page_size": min(page_size, 100),
|
||||||
|
"fields": fields,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
params["q"] = query
|
||||||
|
if page_token:
|
||||||
|
params["page_token"] = page_token
|
||||||
|
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_LIST_FILES",
|
||||||
|
params=params,
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return [], None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
data = result.get("data", {})
|
||||||
|
files = []
|
||||||
|
next_token = None
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
files = inner_data.get("files", [])
|
||||||
|
next_token = inner_data.get("nextPageToken") or inner_data.get(
|
||||||
|
"next_page_token"
|
||||||
|
)
|
||||||
|
elif isinstance(data, list):
|
||||||
|
files = data
|
||||||
|
|
||||||
|
return files, next_token, None
|
||||||
|
|
||||||
|
async def get_file_metadata(
|
||||||
|
self, file_id: str, fields: str = "*"
|
||||||
|
) -> tuple[dict[str, Any] | None, str | None]:
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
|
||||||
|
params={"file_id": file_id, "fields": fields},
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
data = result.get("data", {})
|
||||||
|
if isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
return inner_data, None
|
||||||
|
|
||||||
|
return None, "Could not extract metadata from Composio response"
|
||||||
|
|
||||||
|
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
||||||
|
return await self._download_file_content(file_id)
|
||||||
|
|
||||||
|
async def download_file_to_disk(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
dest_path: str,
|
||||||
|
chunksize: int = 5 * 1024 * 1024,
|
||||||
|
) -> str | None:
|
||||||
|
del chunksize
|
||||||
|
content, error = await self.download_file(file_id)
|
||||||
|
if error:
|
||||||
|
return error
|
||||||
|
if content is None:
|
||||||
|
return "No content returned from Composio"
|
||||||
|
Path(dest_path).write_bytes(content)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def export_google_file(
|
||||||
|
self, file_id: str, mime_type: str
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
return await self._download_file_content(file_id, mime_type=mime_type)
|
||||||
|
|
||||||
|
async def _download_file_content(
|
||||||
|
self, file_id: str, mime_type: str | None = None
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
params: dict[str, Any] = {"file_id": file_id}
|
||||||
|
if mime_type:
|
||||||
|
params["mime_type"] = mime_type
|
||||||
|
|
||||||
|
result = await self.composio.execute_tool(
|
||||||
|
connected_account_id=self.connected_account_id,
|
||||||
|
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
|
||||||
|
params=params,
|
||||||
|
entity_id=self.entity_id,
|
||||||
|
)
|
||||||
|
if not result.get("success"):
|
||||||
|
return None, result.get("error", "Unknown error")
|
||||||
|
|
||||||
|
return self._read_download_result(result.get("data"))
|
||||||
|
|
||||||
|
def _read_download_result(self, data: Any) -> tuple[bytes | None, str | None]:
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data, None
|
||||||
|
|
||||||
|
file_path: str | None = None
|
||||||
|
if isinstance(data, str):
|
||||||
|
file_path = data
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
inner_data = data.get("data", data)
|
||||||
|
if isinstance(inner_data, dict):
|
||||||
|
for key in ("file_path", "downloaded_file_content", "path", "uri"):
|
||||||
|
value = inner_data.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
file_path = value
|
||||||
|
break
|
||||||
|
if isinstance(value, dict):
|
||||||
|
nested = (
|
||||||
|
value.get("file_path")
|
||||||
|
or value.get("downloaded_file_content")
|
||||||
|
or value.get("path")
|
||||||
|
or value.get("uri")
|
||||||
|
or value.get("s3url")
|
||||||
|
)
|
||||||
|
if isinstance(nested, str):
|
||||||
|
file_path = nested
|
||||||
|
break
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
return None, "No file path/content returned from Composio"
|
||||||
|
|
||||||
|
if file_path.startswith(("http://", "https://")):
|
||||||
|
try:
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
with urllib.request.urlopen(file_path, timeout=60) as response:
|
||||||
|
return response.read(), None
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to download Composio file URL: {e!s}"
|
||||||
|
|
||||||
|
path_obj = Path(file_path)
|
||||||
|
if path_obj.is_absolute() or ".composio" in str(path_obj):
|
||||||
|
if not path_obj.exists():
|
||||||
|
return None, f"File not found at path: {file_path}"
|
||||||
|
return path_obj.read_bytes(), None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import base64
|
||||||
|
|
||||||
|
return base64.b64decode(file_path), None
|
||||||
|
except Exception:
|
||||||
|
return file_path.encode("utf-8"), None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_drive_client_for_connector(
|
||||||
|
session: AsyncSession,
|
||||||
|
connector_id: int,
|
||||||
|
connector: object,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[GoogleDriveClient | ComposioDriveClient | None, str | None]:
|
||||||
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not connected_account_id:
|
||||||
|
return None, (
|
||||||
|
f"Composio connected_account_id not found for connector {connector_id}"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
ComposioDriveClient(
|
||||||
|
session,
|
||||||
|
connector_id,
|
||||||
|
connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and not config.SECRET_KEY:
|
||||||
|
return None, "SECRET_KEY not configured but credentials are marked as encrypted"
|
||||||
|
|
||||||
|
return GoogleDriveClient(session, connector_id), None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -927,34 +1130,17 @@ async def index_google_drive_files(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
error_msg,
|
client_error or "Failed to initialize Google Drive client",
|
||||||
"Missing Composio account",
|
"Missing connector credentials",
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "ClientInitializationError"},
|
||||||
)
|
|
||||||
return 0, 0, error_msg, 0
|
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
"SECRET_KEY not configured but credentials are encrypted",
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
|
||||||
0,
|
|
||||||
)
|
)
|
||||||
|
return 0, 0, client_error, 0
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
||||||
|
|
@ -963,10 +1149,6 @@ async def index_google_drive_files(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
if not folder_id:
|
if not folder_id:
|
||||||
error_msg = "folder_id is required for Google Drive indexing"
|
error_msg = "folder_id is required for Google Drive indexing"
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -979,8 +1161,14 @@ async def index_google_drive_files(
|
||||||
|
|
||||||
folder_tokens = connector.config.get("folder_tokens", {})
|
folder_tokens = connector.config.get("folder_tokens", {})
|
||||||
start_page_token = folder_tokens.get(target_folder_id)
|
start_page_token = folder_tokens.get(target_folder_id)
|
||||||
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
can_use_delta = (
|
can_use_delta = (
|
||||||
use_delta_sync and start_page_token and connector.last_indexed_at
|
not is_composio_connector
|
||||||
|
and use_delta_sync
|
||||||
|
and start_page_token
|
||||||
|
and connector.last_indexed_at
|
||||||
)
|
)
|
||||||
|
|
||||||
documents_unsupported = 0
|
documents_unsupported = 0
|
||||||
|
|
@ -1051,6 +1239,15 @@ async def index_google_drive_files(
|
||||||
)
|
)
|
||||||
|
|
||||||
if documents_indexed > 0 or can_use_delta:
|
if documents_indexed > 0 or can_use_delta:
|
||||||
|
if isinstance(drive_client, ComposioDriveClient):
|
||||||
|
(
|
||||||
|
new_token,
|
||||||
|
token_error,
|
||||||
|
) = await drive_client.composio.get_drive_start_page_token(
|
||||||
|
drive_client.connected_account_id,
|
||||||
|
drive_client.entity_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
new_token, token_error = await get_start_page_token(drive_client)
|
new_token, token_error = await get_start_page_token(drive_client)
|
||||||
if new_token and not token_error:
|
if new_token and not token_error:
|
||||||
await session.refresh(connector)
|
await session.refresh(connector)
|
||||||
|
|
@ -1137,32 +1334,17 @@ async def index_google_drive_single_file(
|
||||||
)
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
error_msg,
|
client_error or "Failed to initialize Google Drive client",
|
||||||
"Missing Composio account",
|
"Missing connector credentials",
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "ClientInitializationError"},
|
||||||
)
|
|
||||||
return 0, error_msg
|
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
"SECRET_KEY not configured but credentials are encrypted",
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
0,
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
|
||||||
)
|
)
|
||||||
|
return 0, client_error
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
connector_enable_vision_llm = getattr(connector, "enable_vision_llm", False)
|
||||||
|
|
@ -1171,10 +1353,6 @@ async def index_google_drive_single_file(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
file, error = await get_file_by_id(drive_client, file_id)
|
file, error = await get_file_by_id(drive_client, file_id)
|
||||||
if error or not file:
|
if error or not file:
|
||||||
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
||||||
|
|
@ -1276,30 +1454,16 @@ async def index_google_drive_selected_files(
|
||||||
)
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
||||||
pre_built_credentials = None
|
drive_client, client_error = _build_drive_client_for_connector(
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
session, connector_id, connector, user_id
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
)
|
||||||
if not connected_account_id:
|
if client_error or not drive_client:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
error_msg = client_error or "Failed to initialize Google Drive client"
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
error_msg,
|
error_msg,
|
||||||
"Missing Composio account",
|
"Missing connector credentials",
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "ClientInitializationError"},
|
||||||
)
|
|
||||||
return 0, 0, [error_msg]
|
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
|
||||||
else:
|
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
|
||||||
error_msg = (
|
|
||||||
"SECRET_KEY not configured but credentials are marked as encrypted"
|
|
||||||
)
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
error_msg,
|
|
||||||
"Missing SECRET_KEY",
|
|
||||||
{"error_type": "MissingSecretKey"},
|
|
||||||
)
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
||||||
|
|
@ -1310,10 +1474,6 @@ async def index_google_drive_selected_files(
|
||||||
from app.services.llm_service import get_vision_llm
|
from app.services.llm_service import get_vision_llm
|
||||||
|
|
||||||
vision_llm = await get_vision_llm(session, search_space_id)
|
vision_llm = await get_vision_llm(session, search_space_id)
|
||||||
drive_client = GoogleDriveClient(
|
|
||||||
session, connector_id, credentials=pre_built_credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
indexed, skipped, unsupported, errors = await _index_selected_files(
|
indexed, skipped, unsupported, errors = await _index_selected_files(
|
||||||
drive_client,
|
drive_client,
|
||||||
session,
|
session,
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,10 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
IndexingPipelineService,
|
IndexingPipelineService,
|
||||||
PlaceholderInfo,
|
PlaceholderInfo,
|
||||||
)
|
)
|
||||||
|
from app.services.composio_service import ComposioService
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
|
|
||||||
build_composio_credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
calculate_date_range,
|
calculate_date_range,
|
||||||
|
|
@ -44,6 +42,62 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_composio_gmail_message(message: dict) -> dict:
|
||||||
|
if message.get("payload"):
|
||||||
|
return message
|
||||||
|
|
||||||
|
headers = []
|
||||||
|
header_values = {
|
||||||
|
"Subject": message.get("subject"),
|
||||||
|
"From": message.get("from") or message.get("sender"),
|
||||||
|
"To": message.get("to") or message.get("recipient"),
|
||||||
|
"Date": message.get("date"),
|
||||||
|
}
|
||||||
|
for name, value in header_values.items():
|
||||||
|
if value:
|
||||||
|
headers.append({"name": name, "value": value})
|
||||||
|
|
||||||
|
return {
|
||||||
|
**message,
|
||||||
|
"id": message.get("id")
|
||||||
|
or message.get("message_id")
|
||||||
|
or message.get("messageId"),
|
||||||
|
"threadId": message.get("threadId") or message.get("thread_id"),
|
||||||
|
"payload": {"headers": headers},
|
||||||
|
"snippet": message.get("snippet", ""),
|
||||||
|
"messageText": message.get("messageText") or message.get("body") or "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_gmail_message_to_markdown(message: dict) -> str:
|
||||||
|
headers = {
|
||||||
|
header.get("name", "").lower(): header.get("value", "")
|
||||||
|
for header in message.get("payload", {}).get("headers", [])
|
||||||
|
if isinstance(header, dict)
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "No Subject")
|
||||||
|
from_email = headers.get("from", "Unknown Sender")
|
||||||
|
to_email = headers.get("to", "Unknown Recipient")
|
||||||
|
date_str = headers.get("date", "Unknown Date")
|
||||||
|
message_text = (
|
||||||
|
message.get("messageText")
|
||||||
|
or message.get("body")
|
||||||
|
or message.get("text")
|
||||||
|
or message.get("snippet", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"# {subject}\n\n"
|
||||||
|
f"**From:** {from_email}\n"
|
||||||
|
f"**To:** {to_email}\n"
|
||||||
|
f"**Date:** {date_str}\n\n"
|
||||||
|
f"## Message Content\n\n{message_text}\n\n"
|
||||||
|
f"## Message Details\n\n"
|
||||||
|
f"- **Message ID:** {message.get('id', 'Unknown')}\n"
|
||||||
|
f"- **Thread ID:** {message.get('threadId', 'Unknown')}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_doc(
|
def _build_connector_doc(
|
||||||
message: dict,
|
message: dict,
|
||||||
markdown_content: str,
|
markdown_content: str,
|
||||||
|
|
@ -162,7 +216,14 @@ async def index_google_gmail_messages(
|
||||||
)
|
)
|
||||||
return 0, 0, error_msg
|
return 0, 0, error_msg
|
||||||
|
|
||||||
# ── Credential building ───────────────────────────────────────
|
is_composio_connector = (
|
||||||
|
connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||||
|
)
|
||||||
|
gmail_connector = None
|
||||||
|
composio_service = None
|
||||||
|
connected_account_id = None
|
||||||
|
|
||||||
|
# ── Credential/client building ────────────────────────────────
|
||||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
|
|
@ -173,7 +234,7 @@ async def index_google_gmail_messages(
|
||||||
{"error_type": "MissingComposioAccount"},
|
{"error_type": "MissingComposioAccount"},
|
||||||
)
|
)
|
||||||
return 0, 0, "Composio connected_account_id not found"
|
return 0, 0, "Composio connected_account_id not found"
|
||||||
credentials = build_composio_credentials(connected_account_id)
|
composio_service = ComposioService()
|
||||||
else:
|
else:
|
||||||
config_data = connector.config
|
config_data = connector.config
|
||||||
|
|
||||||
|
|
@ -241,6 +302,7 @@ async def index_google_gmail_messages(
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not is_composio_connector:
|
||||||
gmail_connector = GoogleGmailConnector(
|
gmail_connector = GoogleGmailConnector(
|
||||||
credentials, session, user_id, connector_id
|
credentials, session, user_id, connector_id
|
||||||
)
|
)
|
||||||
|
|
@ -254,6 +316,55 @@ async def index_google_gmail_messages(
|
||||||
f"Fetching emails for connector {connector_id} "
|
f"Fetching emails for connector {connector_id} "
|
||||||
f"from {calculated_start_date} to {calculated_end_date}"
|
f"from {calculated_start_date} to {calculated_end_date}"
|
||||||
)
|
)
|
||||||
|
if is_composio_connector:
|
||||||
|
query_parts = []
|
||||||
|
if calculated_start_date:
|
||||||
|
query_parts.append(f"after:{calculated_start_date.replace('-', '/')}")
|
||||||
|
if calculated_end_date:
|
||||||
|
query_parts.append(f"before:{calculated_end_date.replace('-', '/')}")
|
||||||
|
query = " ".join(query_parts)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
page_token = None
|
||||||
|
error = None
|
||||||
|
while len(messages) < max_messages:
|
||||||
|
page_size = min(50, max_messages - len(messages))
|
||||||
|
(
|
||||||
|
page_messages,
|
||||||
|
page_token,
|
||||||
|
_estimate,
|
||||||
|
page_error,
|
||||||
|
) = await composio_service.get_gmail_messages(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
query=query,
|
||||||
|
max_results=page_size,
|
||||||
|
page_token=page_token,
|
||||||
|
)
|
||||||
|
if page_error:
|
||||||
|
error = page_error
|
||||||
|
break
|
||||||
|
for page_message in page_messages:
|
||||||
|
message_id = (
|
||||||
|
page_message.get("id")
|
||||||
|
or page_message.get("message_id")
|
||||||
|
or page_message.get("messageId")
|
||||||
|
)
|
||||||
|
if message_id:
|
||||||
|
(
|
||||||
|
detail,
|
||||||
|
detail_error,
|
||||||
|
) = await composio_service.get_gmail_message_detail(
|
||||||
|
connected_account_id=connected_account_id,
|
||||||
|
entity_id=f"surfsense_{user_id}",
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
if not detail_error and isinstance(detail, dict):
|
||||||
|
page_message = detail
|
||||||
|
messages.append(_normalize_composio_gmail_message(page_message))
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
else:
|
||||||
messages, error = await gmail_connector.get_recent_messages(
|
messages, error = await gmail_connector.get_recent_messages(
|
||||||
max_results=max_messages,
|
max_results=max_messages,
|
||||||
start_date=calculated_start_date,
|
start_date=calculated_start_date,
|
||||||
|
|
@ -326,7 +437,12 @@ async def index_google_gmail_messages(
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
markdown_content = gmail_connector.format_message_to_markdown(message)
|
if is_composio_connector:
|
||||||
|
markdown_content = _format_gmail_message_to_markdown(message)
|
||||||
|
else:
|
||||||
|
markdown_content = gmail_connector.format_message_to_markdown(
|
||||||
|
message
|
||||||
|
)
|
||||||
if not markdown_content.strip():
|
if not markdown_content.strip():
|
||||||
logger.warning(f"Skipping message with no content: {message_id}")
|
logger.warning(f"Skipping message with no content: {message_id}")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "surf-new-backend"
|
name = "surf-new-backend"
|
||||||
version = "0.0.19"
|
version = "0.0.20"
|
||||||
description = "SurfSense Backend"
|
description = "SurfSense Backend"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|
@ -71,11 +71,11 @@ dependencies = [
|
||||||
"langchain>=1.2.13",
|
"langchain>=1.2.13",
|
||||||
"langgraph>=1.1.3",
|
"langgraph>=1.1.3",
|
||||||
"langchain-community>=0.4.1",
|
"langchain-community>=0.4.1",
|
||||||
"deepagents>=0.4.12",
|
|
||||||
"stripe>=15.0.0",
|
"stripe>=15.0.0",
|
||||||
"azure-ai-documentintelligence>=1.0.2",
|
"azure-ai-documentintelligence>=1.0.2",
|
||||||
"litellm>=1.83.7",
|
"litellm>=1.83.7",
|
||||||
"langchain-litellm>=0.6.4",
|
"langchain-litellm>=0.6.4",
|
||||||
|
"deepagents>=0.4.12,<0.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|
|
||||||
558
surfsense_backend/scripts/verify_chat_image_capability.py
Normal file
558
surfsense_backend/scripts/verify_chat_image_capability.py
Normal file
|
|
@ -0,0 +1,558 @@
|
||||||
|
"""End-to-end smoke test for vision / image config wiring.
|
||||||
|
|
||||||
|
Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
|
||||||
|
exercises every chat / vision / image-generation config + the OpenRouter
|
||||||
|
dynamic catalog. For each config the script:
|
||||||
|
|
||||||
|
1. Reports the resolver classification (catalog-allow vs strict-block).
|
||||||
|
2. Optionally fires a tiny live API call against the provider:
|
||||||
|
- Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
|
||||||
|
``"reply with one word: ok"``.
|
||||||
|
- Vision configs: same, against the dedicated vision router pool.
|
||||||
|
- Image-gen configs: ``litellm.aimage_generation`` with a single tiny
|
||||||
|
prompt and ``n=1``.
|
||||||
|
- OpenRouter integration: samples one chat, one vision, one image-gen
|
||||||
|
model from the dynamically fetched catalog.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python -m scripts.verify_chat_image_capability # capability + connectivity
|
||||||
|
python -m scripts.verify_chat_image_capability --no-live # capability resolver only
|
||||||
|
|
||||||
|
The script is meant to be runnable from the repository root or from
|
||||||
|
``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
|
||||||
|
end so it's usable as a CI smoke check too.
|
||||||
|
|
||||||
|
Live-mode caveat: each successful call costs a small amount of provider
|
||||||
|
credit (a few tokens or one tiny generated image per config). The
|
||||||
|
default size for image generation is ``1024x1024`` because Azure
|
||||||
|
GPT-image deployments reject smaller sizes; OpenRouter image-gen models
|
||||||
|
generally accept the same size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Bootstrap the surfsense_backend package on sys.path so the script runs
|
||||||
|
# from the repo root or from `surfsense_backend/` interchangeably.
|
||||||
|
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_BACKEND_ROOT = os.path.dirname(_HERE)
|
||||||
|
if _BACKEND_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _BACKEND_ROOT)
|
||||||
|
|
||||||
|
import litellm # noqa: E402
|
||||||
|
|
||||||
|
from app.config import config # noqa: E402
|
||||||
|
from app.services.openrouter_integration_service import ( # noqa: E402
|
||||||
|
_OPENROUTER_DYNAMIC_MARKER,
|
||||||
|
OpenRouterIntegrationService,
|
||||||
|
)
|
||||||
|
from app.services.provider_api_base import resolve_api_base # noqa: E402
|
||||||
|
from app.services.provider_capabilities import ( # noqa: E402
|
||||||
|
derive_supports_image_input,
|
||||||
|
is_known_text_only_chat_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARNING,
|
||||||
|
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||||
|
)
|
||||||
|
# Quiet down LiteLLM's verbose router/cost logs so the script output is
|
||||||
|
# scannable.
|
||||||
|
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("litellm").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# 1x1 transparent PNG — used as the cheapest possible vision payload.
|
||||||
|
_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Result accounting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProbeResult:
|
||||||
|
label: str
|
||||||
|
surface: str
|
||||||
|
config_id: int | str
|
||||||
|
capability_ok: bool | None = None
|
||||||
|
capability_note: str = ""
|
||||||
|
live_ok: bool | None = None
|
||||||
|
live_note: str = ""
|
||||||
|
duration_s: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Report:
|
||||||
|
results: list[ProbeResult] = field(default_factory=list)
|
||||||
|
|
||||||
|
def add(self, r: ProbeResult) -> None:
|
||||||
|
self.results.append(r)
|
||||||
|
|
||||||
|
def render(self) -> int:
|
||||||
|
passed = failed = skipped = 0
|
||||||
|
print()
|
||||||
|
print("=" * 92)
|
||||||
|
print(
|
||||||
|
f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
|
||||||
|
)
|
||||||
|
print("-" * 92)
|
||||||
|
for r in self.results:
|
||||||
|
|
||||||
|
def _flag(value: bool | None) -> str:
|
||||||
|
if value is None:
|
||||||
|
return "skip"
|
||||||
|
return "ok" if value else "fail"
|
||||||
|
|
||||||
|
cap = _flag(r.capability_ok)
|
||||||
|
live = _flag(r.live_ok)
|
||||||
|
if r.capability_ok is False or r.live_ok is False:
|
||||||
|
failed += 1
|
||||||
|
elif r.capability_ok is None and r.live_ok is None:
|
||||||
|
skipped += 1
|
||||||
|
else:
|
||||||
|
passed += 1
|
||||||
|
print(
|
||||||
|
f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} "
|
||||||
|
f"{r.duration_s:>5.2f}s {r.label}"
|
||||||
|
)
|
||||||
|
if r.capability_note:
|
||||||
|
print(f" cap: {r.capability_note}")
|
||||||
|
if r.live_note:
|
||||||
|
print(f" live: {r.live_note}")
|
||||||
|
print("-" * 92)
|
||||||
|
print(
|
||||||
|
f"Total: {passed} ok / {failed} fail / {skipped} skip "
|
||||||
|
f"(of {len(self.results)} probes)"
|
||||||
|
)
|
||||||
|
print("=" * 92)
|
||||||
|
return failed
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Capability probes (no network)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
||||||
|
"""For chat configs the catalog flag is *expected* True (vision-capable
|
||||||
|
pool). The probe reports both the resolver value and the strict
|
||||||
|
safety-net value to surface any drift between them."""
|
||||||
|
litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||||
|
)
|
||||||
|
cap = derive_supports_image_input(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
model_name=cfg.get("model_name"),
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
block = is_known_text_only_chat_model(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
model_name=cfg.get("model_name"),
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=cfg.get("custom_provider"),
|
||||||
|
)
|
||||||
|
note = f"derive={cap} strict_block={block}"
|
||||||
|
if not cap and not block:
|
||||||
|
# Resolver said False but strict gate is also False — that means
|
||||||
|
# OR modalities published [text] explicitly. Surface it.
|
||||||
|
note += " (OR modality says text-only)"
|
||||||
|
# We accept a True derive *or* (False derive AND False block) as
|
||||||
|
# 'capability ok' — either way, the streaming task will flow through.
|
||||||
|
ok = cap or not block
|
||||||
|
return ok, note
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chat_model_string(cfg: dict) -> str:
|
||||||
|
if cfg.get("custom_provider"):
|
||||||
|
return f"{cfg['custom_provider']}/{cfg['model_name']}"
|
||||||
|
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||||
|
|
||||||
|
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||||
|
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||||
|
)
|
||||||
|
return f"{prefix}/{cfg['model_name']}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Live probes (network calls)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
||||||
|
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
|
||||||
|
model_string = _build_chat_model_string(cfg)
|
||||||
|
api_base = resolve_api_base(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
provider_prefix=model_string.split("/", 1)[0],
|
||||||
|
config_api_base=cfg.get("api_base") or None,
|
||||||
|
)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": model_string,
|
||||||
|
"api_key": cfg.get("api_key"),
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "reply with one word: ok"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": _TINY_PNG_DATA_URL},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 16,
|
||||||
|
"timeout": 60,
|
||||||
|
}
|
||||||
|
if api_base:
|
||||||
|
kwargs["api_base"] = api_base
|
||||||
|
if cfg.get("litellm_params"):
|
||||||
|
# Strip pricing keys — they're tracking-only and confuse some
|
||||||
|
# provider validators (e.g. azure/openai reject unknown kwargs
|
||||||
|
# in strict mode).
|
||||||
|
merged = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(cfg["litellm_params"]).items()
|
||||||
|
if k
|
||||||
|
not in {
|
||||||
|
"input_cost_per_token",
|
||||||
|
"output_cost_per_token",
|
||||||
|
"input_cost_per_pixel",
|
||||||
|
"output_cost_per_pixel",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kwargs.update(merged)
|
||||||
|
try:
|
||||||
|
resp = await litellm.acompletion(**kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
return False, f"{type(exc).__name__}: {exc}"
|
||||||
|
text = resp.choices[0].message.content if resp.choices else ""
|
||||||
|
return True, f"got reply ({(text or '').strip()[:40]!r})"
|
||||||
|
|
||||||
|
|
||||||
|
# Gemini image models occasionally return zero-length ``data`` for the
|
||||||
|
# minimal "red dot on white" prompt (provider-side safety / empty-output
|
||||||
|
# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
|
||||||
|
# the request itself succeeds). Use a more naturalistic prompt and
|
||||||
|
# retry once with a different one before giving up.
|
||||||
|
_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
|
||||||
|
"A simple icon of a coffee cup, flat illustration",
|
||||||
|
"A small green leaf on a white background",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
||||||
|
"""Generate one tiny image to verify the deployment is reachable."""
|
||||||
|
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||||
|
|
||||||
|
if cfg.get("custom_provider"):
|
||||||
|
prefix = cfg["custom_provider"]
|
||||||
|
else:
|
||||||
|
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||||
|
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||||
|
)
|
||||||
|
model_string = f"{prefix}/{cfg['model_name']}"
|
||||||
|
api_base = resolve_api_base(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
provider_prefix=prefix,
|
||||||
|
config_api_base=cfg.get("api_base") or None,
|
||||||
|
)
|
||||||
|
base_kwargs: dict[str, Any] = {
|
||||||
|
"model": model_string,
|
||||||
|
"api_key": cfg.get("api_key"),
|
||||||
|
"n": 1,
|
||||||
|
"size": "1024x1024",
|
||||||
|
"timeout": 120,
|
||||||
|
}
|
||||||
|
if api_base:
|
||||||
|
base_kwargs["api_base"] = api_base
|
||||||
|
if cfg.get("api_version"):
|
||||||
|
base_kwargs["api_version"] = cfg["api_version"]
|
||||||
|
if cfg.get("litellm_params"):
|
||||||
|
base_kwargs.update(
|
||||||
|
{
|
||||||
|
k: v
|
||||||
|
for k, v in dict(cfg["litellm_params"]).items()
|
||||||
|
if k
|
||||||
|
not in {
|
||||||
|
"input_cost_per_token",
|
||||||
|
"output_cost_per_token",
|
||||||
|
"input_cost_per_pixel",
|
||||||
|
"output_cost_per_pixel",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
last_note = ""
|
||||||
|
for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1):
|
||||||
|
try:
|
||||||
|
resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
last_note = f"{type(exc).__name__}: {exc}"
|
||||||
|
continue
|
||||||
|
data_count = len(getattr(resp, "data", None) or [])
|
||||||
|
if data_count > 0:
|
||||||
|
return True, (
|
||||||
|
f"received {data_count} image(s) on attempt {attempt} "
|
||||||
|
f"(prompt={prompt!r})"
|
||||||
|
)
|
||||||
|
last_note = (
|
||||||
|
f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
|
||||||
|
)
|
||||||
|
return False, last_note
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Probe drivers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _is_or_dynamic(cfg: dict) -> bool:
|
||||||
|
return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER))
|
||||||
|
|
||||||
|
|
||||||
|
async def probe_chat_configs(report: Report, *, live: bool) -> None:
|
||||||
|
print("\n[chat configs from global_llm_configs (YAML-static)]")
|
||||||
|
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||||
|
# Skip OR dynamic entries here — handled in the OR section so
|
||||||
|
# the YAML / OR split stays clear in the report.
|
||||||
|
if _is_or_dynamic(cfg):
|
||||||
|
continue
|
||||||
|
result = ProbeResult(
|
||||||
|
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||||
|
surface="chat-yaml",
|
||||||
|
config_id=cfg.get("id"),
|
||||||
|
)
|
||||||
|
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||||
|
result.capability_ok = cap_ok
|
||||||
|
result.capability_note = cap_note
|
||||||
|
if live:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ok, note = await _live_chat_image_call(cfg)
|
||||||
|
result.live_ok = ok
|
||||||
|
result.live_note = note
|
||||||
|
result.duration_s = time.perf_counter() - t0
|
||||||
|
report.add(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def probe_vision_configs(report: Report, *, live: bool) -> None:
|
||||||
|
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
|
||||||
|
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||||
|
if _is_or_dynamic(cfg):
|
||||||
|
continue
|
||||||
|
result = ProbeResult(
|
||||||
|
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||||
|
surface="vision",
|
||||||
|
config_id=cfg.get("id"),
|
||||||
|
)
|
||||||
|
# For vision configs, capability is implied — they're in the
|
||||||
|
# dedicated vision pool. Run the same resolver to flag any
|
||||||
|
# surprise disagreement.
|
||||||
|
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||||
|
result.capability_ok = cap_ok
|
||||||
|
result.capability_note = cap_note
|
||||||
|
if live:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ok, note = await _live_chat_image_call(cfg)
|
||||||
|
result.live_ok = ok
|
||||||
|
result.live_note = note
|
||||||
|
result.duration_s = time.perf_counter() - t0
|
||||||
|
report.add(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
||||||
|
print(
|
||||||
|
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
|
||||||
|
)
|
||||||
|
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||||
|
if _is_or_dynamic(cfg):
|
||||||
|
continue
|
||||||
|
result = ProbeResult(
|
||||||
|
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||||
|
surface="image-gen",
|
||||||
|
config_id=cfg.get("id"),
|
||||||
|
)
|
||||||
|
# Image gen configs don't have a "supports_image_input" flag;
|
||||||
|
# the catalog tracks output, not input. Mark capability as None
|
||||||
|
# (skip) for the report.
|
||||||
|
if live:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ok, note = await _live_image_gen_call(cfg)
|
||||||
|
result.live_ok = ok
|
||||||
|
result.live_note = note
|
||||||
|
result.duration_s = time.perf_counter() - t0
|
||||||
|
report.add(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
||||||
|
"""Sample one chat (vision-capable), one vision, one image-gen model
|
||||||
|
from the live OpenRouter catalogue. Doesn't iterate the full pool
|
||||||
|
(would be hundreds of probes); just validates the integration end-
|
||||||
|
to-end on a representative model from each surface."""
|
||||||
|
print("\n[OpenRouter integration: sampled probes]")
|
||||||
|
settings = config.OPENROUTER_INTEGRATION_SETTINGS
|
||||||
|
if not settings:
|
||||||
|
report.add(
|
||||||
|
ProbeResult(
|
||||||
|
label="OpenRouter integration",
|
||||||
|
surface="openrouter",
|
||||||
|
config_id="settings",
|
||||||
|
capability_ok=None,
|
||||||
|
capability_note="openrouter_integration disabled in YAML — skipping",
|
||||||
|
live_ok=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService.get_instance()
|
||||||
|
or_chat = [
|
||||||
|
c
|
||||||
|
for c in config.GLOBAL_LLM_CONFIGS
|
||||||
|
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
|
||||||
|
]
|
||||||
|
or_vision = [
|
||||||
|
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||||
|
]
|
||||||
|
or_image_gen = [
|
||||||
|
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Pick one representative per provider family per surface so a single
|
||||||
|
# broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
|
||||||
|
# surfaces independently of the others. Each needle matches the
|
||||||
|
# OpenRouter ``model_name`` prefix; the first match wins.
|
||||||
|
def _pick_first(pool: list[dict], needle: str) -> dict | None:
|
||||||
|
for c in pool:
|
||||||
|
if (c.get("model_name") or "").lower().startswith(needle):
|
||||||
|
return c
|
||||||
|
return None
|
||||||
|
|
||||||
|
chat_picks = [
|
||||||
|
("or-chat", _pick_first(or_chat, "openai/gpt-4o")),
|
||||||
|
("or-chat", _pick_first(or_chat, "anthropic/claude")),
|
||||||
|
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
|
||||||
|
]
|
||||||
|
vision_picks = [
|
||||||
|
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
|
||||||
|
("or-vision", _pick_first(or_vision, "anthropic/claude")),
|
||||||
|
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
|
||||||
|
]
|
||||||
|
image_picks = [
|
||||||
|
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
|
||||||
|
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
|
||||||
|
# / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
|
||||||
|
# the actual prefix.
|
||||||
|
("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")),
|
||||||
|
]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
|
||||||
|
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
for surface, picked in chat_picks + vision_picks + image_picks:
|
||||||
|
if not picked:
|
||||||
|
report.add(
|
||||||
|
ProbeResult(
|
||||||
|
label=f"<no candidate for {surface}>",
|
||||||
|
surface=surface,
|
||||||
|
config_id="-",
|
||||||
|
capability_ok=None,
|
||||||
|
capability_note="no candidate found in OR catalog",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
runner = (
|
||||||
|
_live_image_gen_call if surface == "or-image" else _live_chat_image_call
|
||||||
|
)
|
||||||
|
result = ProbeResult(
|
||||||
|
label=str(picked.get("model_name")),
|
||||||
|
surface=surface,
|
||||||
|
config_id=picked.get("id"),
|
||||||
|
)
|
||||||
|
if surface != "or-image":
|
||||||
|
cap_ok, cap_note = _probe_chat_capability(picked)
|
||||||
|
result.capability_ok = cap_ok
|
||||||
|
result.capability_note = cap_note
|
||||||
|
if live:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ok, note = await runner(picked)
|
||||||
|
result.live_ok = ok
|
||||||
|
result.live_note = note
|
||||||
|
result.duration_s = time.perf_counter() - t0
|
||||||
|
report.add(result)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args: argparse.Namespace) -> int:
|
||||||
|
print("Loaded global configs:")
|
||||||
|
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
|
||||||
|
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
|
||||||
|
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
|
||||||
|
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
|
||||||
|
|
||||||
|
# Initialize the OpenRouter integration so the catalog is populated
|
||||||
|
# (this is what main.py does at startup). It's idempotent.
|
||||||
|
if config.OPENROUTER_INTEGRATION_SETTINGS:
|
||||||
|
try:
|
||||||
|
from app.config import initialize_openrouter_integration
|
||||||
|
|
||||||
|
initialize_openrouter_integration()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f" WARNING: OpenRouter integration init failed: {exc}")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
report = Report()
|
||||||
|
if not args.skip_chat:
|
||||||
|
await probe_chat_configs(report, live=args.live)
|
||||||
|
if not args.skip_vision:
|
||||||
|
await probe_vision_configs(report, live=args.live)
|
||||||
|
if not args.skip_image_gen:
|
||||||
|
await probe_image_gen_configs(report, live=args.live)
|
||||||
|
if not args.skip_openrouter:
|
||||||
|
await probe_openrouter_catalog(report, live=args.live)
|
||||||
|
|
||||||
|
failed = report.render()
|
||||||
|
return 1 if failed else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-live",
|
||||||
|
dest="live",
|
||||||
|
action="store_false",
|
||||||
|
help="Skip live API calls — capability resolver only.",
|
||||||
|
)
|
||||||
|
parser.set_defaults(live=True)
|
||||||
|
parser.add_argument("--skip-chat", action="store_true")
|
||||||
|
parser.add_argument("--skip-vision", action="store_true")
|
||||||
|
parser.add_argument("--skip-image-gen", action="store_true")
|
||||||
|
parser.add_argument("--skip-openrouter", action="store_true")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = _parse_args()
|
||||||
|
sys.exit(asyncio.run(main(args)))
|
||||||
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
|
|
@ -0,0 +1,268 @@
|
||||||
|
"""Regression tests for the compiled-agent cache.
|
||||||
|
|
||||||
|
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
|
||||||
|
build-failure non-caching) and the cache-key signature helpers that
|
||||||
|
``create_surfsense_deep_agent`` relies on. The integration with
|
||||||
|
``create_surfsense_deep_agent`` is covered separately by the streaming
|
||||||
|
contract tests; this module focuses on the primitives so a regression
|
||||||
|
in the cache implementation is caught before it reaches the agent
|
||||||
|
factory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.agent_cache import (
|
||||||
|
flags_signature,
|
||||||
|
reload_for_tests,
|
||||||
|
stable_hash,
|
||||||
|
system_prompt_hash,
|
||||||
|
tools_signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# stable_hash + signature helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_hash_is_deterministic_across_calls() -> None:
|
||||||
|
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
|
||||||
|
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_hash_changes_when_any_part_changes() -> None:
|
||||||
|
base = stable_hash("v1", 42, "thread-9")
|
||||||
|
assert stable_hash("v1", 42, "thread-10") != base
|
||||||
|
assert stable_hash("v2", 42, "thread-9") != base
|
||||||
|
assert stable_hash("v1", 43, "thread-9") != base
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
|
||||||
|
"""Two tool lists with the same surface must hash identically.
|
||||||
|
|
||||||
|
The cache key MUST NOT change when the underlying ``BaseTool``
|
||||||
|
instances are different Python objects (a fresh request constructs
|
||||||
|
fresh tool instances every time). Hashing on ``(name, description)``
|
||||||
|
keeps the cache hot across requests with identical tool surfaces.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeTool:
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
|
||||||
|
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
|
||||||
|
sig_a = tools_signature(
|
||||||
|
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||||
|
)
|
||||||
|
sig_b = tools_signature(
|
||||||
|
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||||
|
)
|
||||||
|
assert sig_a == sig_b, "tool order must not affect the signature"
|
||||||
|
|
||||||
|
# Adding a tool rotates the key.
|
||||||
|
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
|
||||||
|
sig_c = tools_signature(
|
||||||
|
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||||
|
)
|
||||||
|
assert sig_c != sig_a
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_signature_rotates_when_connector_set_changes() -> None:
|
||||||
|
@dataclass
|
||||||
|
class FakeTool:
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
tools = [FakeTool("a", "x")]
|
||||||
|
base = tools_signature(
|
||||||
|
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||||
|
)
|
||||||
|
added = tools_signature(
|
||||||
|
tools,
|
||||||
|
available_connectors=["NOTION", "SLACK"],
|
||||||
|
available_document_types=["FILE"],
|
||||||
|
)
|
||||||
|
assert base != added, "adding a connector must rotate the cache key"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flags_signature_changes_when_flag_flips() -> None:
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Flags:
|
||||||
|
a: bool = True
|
||||||
|
b: bool = False
|
||||||
|
|
||||||
|
base = flags_signature(Flags())
|
||||||
|
flipped = flags_signature(Flags(b=True))
|
||||||
|
assert base != flipped
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_prompt_hash_is_stable_and_distinct() -> None:
|
||||||
|
p1 = "You are a helpful assistant."
|
||||||
|
p2 = "You are a helpful assistant!" # one-character delta
|
||||||
|
assert system_prompt_hash(p1) == system_prompt_hash(p1)
|
||||||
|
assert system_prompt_hash(p1) != system_prompt_hash(p2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||||
|
builds = 0
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
nonlocal builds
|
||||||
|
builds += 1
|
||||||
|
return object()
|
||||||
|
|
||||||
|
a = await cache.get_or_build("k", builder=builder)
|
||||||
|
b = await cache.get_or_build("k", builder=builder)
|
||||||
|
assert a is b, "cache must return the SAME object across hits"
|
||||||
|
assert builds == 1, "builder must run exactly once"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_different_keys_get_different_instances() -> None:
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
a = await cache.get_or_build("k1", builder=builder)
|
||||||
|
b = await cache.get_or_build("k2", builder=builder)
|
||||||
|
assert a is not b
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_stale_entries_get_rebuilt() -> None:
|
||||||
|
# ttl=0 means every read sees the entry as immediately stale.
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
|
||||||
|
builds = 0
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
nonlocal builds
|
||||||
|
builds += 1
|
||||||
|
return object()
|
||||||
|
|
||||||
|
a = await cache.get_or_build("k", builder=builder)
|
||||||
|
b = await cache.get_or_build("k", builder=builder)
|
||||||
|
assert a is not b, "stale entry must rebuild a fresh instance"
|
||||||
|
assert builds == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_evicts_lru_when_full() -> None:
|
||||||
|
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
a = await cache.get_or_build("a", builder=builder)
|
||||||
|
_ = await cache.get_or_build("b", builder=builder)
|
||||||
|
# Re-touch "a" so "b" is now the LRU victim.
|
||||||
|
a_again = await cache.get_or_build("a", builder=builder)
|
||||||
|
assert a_again is a
|
||||||
|
# Inserting "c" should evict "b" (LRU), not "a".
|
||||||
|
_ = await cache.get_or_build("c", builder=builder)
|
||||||
|
assert cache.stats()["size"] == 2
|
||||||
|
|
||||||
|
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
|
||||||
|
a_hit = await cache.get_or_build("a", builder=builder)
|
||||||
|
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
|
||||||
|
"""Two concurrent get_or_build calls on the same key must share one builder."""
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||||
|
build_started = asyncio.Event()
|
||||||
|
builds = 0
|
||||||
|
|
||||||
|
async def slow_builder() -> object:
|
||||||
|
nonlocal builds
|
||||||
|
builds += 1
|
||||||
|
build_started.set()
|
||||||
|
# Yield control so the second waiter can race against us.
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
|
||||||
|
# Wait until the first builder has started, then race a second waiter.
|
||||||
|
await build_started.wait()
|
||||||
|
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
|
||||||
|
|
||||||
|
a, b = await asyncio.gather(task_a, task_b)
|
||||||
|
assert a is b, "coalesced waiters must observe the same value"
|
||||||
|
assert builds == 1, "concurrent cold misses must collapse to ONE build"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_does_not_store_failed_builds() -> None:
|
||||||
|
"""A builder that raises must NOT poison the cache.
|
||||||
|
|
||||||
|
The next caller for the same key must run the builder again (not
|
||||||
|
re-raise the cached exception).
|
||||||
|
"""
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||||
|
attempts = 0
|
||||||
|
|
||||||
|
async def flaky_builder() -> object:
|
||||||
|
nonlocal attempts
|
||||||
|
attempts += 1
|
||||||
|
if attempts == 1:
|
||||||
|
raise RuntimeError("transient")
|
||||||
|
return object()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="transient"):
|
||||||
|
await cache.get_or_build("k", builder=flaky_builder)
|
||||||
|
|
||||||
|
# Second call must retry — not re-raise the cached exception.
|
||||||
|
value = await cache.get_or_build("k", builder=flaky_builder)
|
||||||
|
assert value is not None
|
||||||
|
assert attempts == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_invalidate_drops_entry() -> None:
|
||||||
|
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
a = await cache.get_or_build("k", builder=builder)
|
||||||
|
assert cache.invalidate("k") is True
|
||||||
|
b = await cache.get_or_build("k", builder=builder)
|
||||||
|
assert a is not b, "post-invalidation lookup must rebuild"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
|
||||||
|
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
|
||||||
|
|
||||||
|
async def builder() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
await cache.get_or_build("user:1:thread:1", builder=builder)
|
||||||
|
await cache.get_or_build("user:1:thread:2", builder=builder)
|
||||||
|
await cache.get_or_build("user:2:thread:1", builder=builder)
|
||||||
|
|
||||||
|
removed = cache.invalidate_prefix("user:1:")
|
||||||
|
assert removed == 2
|
||||||
|
assert cache.stats()["size"] == 1
|
||||||
|
|
||||||
|
# The user:2 entry must still be hot (no rebuild).
|
||||||
|
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
|
||||||
|
assert survivor_value is not None
|
||||||
|
|
@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||||
|
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||||
"SURFSENSE_ENABLE_OTEL",
|
"SURFSENSE_ENABLE_OTEL",
|
||||||
|
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||||
|
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
|
||||||
]:
|
]:
|
||||||
monkeypatch.delenv(name, raising=False)
|
monkeypatch.delenv(name, raising=False)
|
||||||
|
|
||||||
|
|
||||||
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
_clear_all(monkeypatch)
|
_clear_all(monkeypatch)
|
||||||
flags = reload_for_tests()
|
flags = reload_for_tests()
|
||||||
assert isinstance(flags, AgentFeatureFlags)
|
assert isinstance(flags, AgentFeatureFlags)
|
||||||
assert flags.disable_new_agent_stack is False
|
assert flags.disable_new_agent_stack is False
|
||||||
assert flags.any_new_middleware_enabled() is False
|
assert flags.enable_context_editing is True
|
||||||
|
assert flags.enable_compaction_v2 is True
|
||||||
|
assert flags.enable_retry_after is True
|
||||||
|
assert flags.enable_model_fallback is False
|
||||||
|
assert flags.enable_model_call_limit is True
|
||||||
|
assert flags.enable_tool_call_limit is True
|
||||||
|
assert flags.enable_tool_call_repair is True
|
||||||
|
assert flags.enable_doom_loop is True
|
||||||
|
assert flags.enable_permission is True
|
||||||
|
assert flags.enable_busy_mutex is True
|
||||||
|
assert flags.enable_llm_tool_selector is False
|
||||||
|
assert flags.enable_skills is True
|
||||||
|
assert flags.enable_specialized_subagents is True
|
||||||
|
assert flags.enable_kb_planner_runnable is True
|
||||||
|
assert flags.enable_action_log is True
|
||||||
|
assert flags.enable_revert_route is True
|
||||||
|
assert flags.enable_stream_parity_v2 is True
|
||||||
|
assert flags.enable_plugin_loader is False
|
||||||
|
assert flags.enable_otel is False
|
||||||
|
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||||
|
# ``db_session`` refactor landed). The companion gp-subagent share
|
||||||
|
# flag stays default-off pending data on cold-miss frequency.
|
||||||
|
assert flags.enable_agent_cache is True
|
||||||
|
assert flags.enable_agent_cache_share_gp_subagent is False
|
||||||
|
assert flags.any_new_middleware_enabled() is True
|
||||||
|
|
||||||
|
|
||||||
def test_master_kill_switch_overrides_individual_flags(
|
def test_master_kill_switch_overrides_individual_flags(
|
||||||
|
|
@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
||||||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||||
|
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||||
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
||||||
}
|
}
|
||||||
|
|
||||||
# `enable_otel` is intentionally orthogonal — it does NOT count toward
|
|
||||||
# ``any_new_middleware_enabled`` because OTel is observability-only and
|
|
||||||
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
|
|
||||||
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
|
|
||||||
|
|
||||||
for attr, env_name in flag_to_env.items():
|
for attr, env_name in flag_to_env.items():
|
||||||
_clear_all(monkeypatch)
|
_clear_all(monkeypatch)
|
||||||
monkeypatch.setenv(env_name, "true")
|
monkeypatch.setenv(env_name, "false")
|
||||||
flags = reload_for_tests()
|
flags = reload_for_tests()
|
||||||
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
|
assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
|
||||||
if attr in counts_toward_middleware:
|
|
||||||
assert flags.any_new_middleware_enabled() is True
|
|
||||||
else:
|
|
||||||
assert flags.any_new_middleware_enabled() is False
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,344 @@
|
||||||
|
"""Tests for ``FlattenSystemMessageMiddleware``.
|
||||||
|
|
||||||
|
The middleware exists to defend against Anthropic's "Found 5 cache_control
|
||||||
|
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
|
||||||
|
the system message and the OpenRouter→Anthropic adapter redistributes
|
||||||
|
``cache_control`` across all of them. The flattening collapses every
|
||||||
|
all-text system content list to a single string before the LLM call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.flatten_system import (
|
||||||
|
FlattenSystemMessageMiddleware,
|
||||||
|
_flatten_text_blocks,
|
||||||
|
_flattened_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _flatten_text_blocks — pure helper, the heart of the middleware.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlattenTextBlocks:
|
||||||
|
def test_joins_text_blocks_with_double_newline(self) -> None:
|
||||||
|
blocks = [
|
||||||
|
{"type": "text", "text": "<surfsense base>"},
|
||||||
|
{"type": "text", "text": "<filesystem section>"},
|
||||||
|
{"type": "text", "text": "<skills section>"},
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
_flatten_text_blocks(blocks)
|
||||||
|
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_handles_single_text_block(self) -> None:
|
||||||
|
blocks = [{"type": "text", "text": "only one"}]
|
||||||
|
assert _flatten_text_blocks(blocks) == "only one"
|
||||||
|
|
||||||
|
def test_handles_empty_list(self) -> None:
|
||||||
|
assert _flatten_text_blocks([]) == ""
|
||||||
|
|
||||||
|
def test_passes_through_bare_string_blocks(self) -> None:
|
||||||
|
# LangChain content can mix bare strings and dict blocks.
|
||||||
|
blocks = ["raw string", {"type": "text", "text": "dict block"}]
|
||||||
|
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
|
||||||
|
|
||||||
|
def test_returns_none_for_image_block(self) -> None:
|
||||||
|
# System messages with images are rare — but we never want to
|
||||||
|
# silently lose the image payload by joining as text.
|
||||||
|
blocks = [
|
||||||
|
{"type": "text", "text": "look at this"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
|
||||||
|
]
|
||||||
|
assert _flatten_text_blocks(blocks) is None
|
||||||
|
|
||||||
|
def test_returns_none_for_non_dict_non_str_block(self) -> None:
|
||||||
|
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
|
||||||
|
assert _flatten_text_blocks(blocks) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_text_field_missing(self) -> None:
|
||||||
|
blocks = [{"type": "text"}] # no ``text`` key
|
||||||
|
assert _flatten_text_blocks(blocks) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_text_is_not_string(self) -> None:
|
||||||
|
blocks = [{"type": "text", "text": ["nested", "list"]}]
|
||||||
|
assert _flatten_text_blocks(blocks) is None
|
||||||
|
|
||||||
|
def test_drops_cache_control_from_inner_blocks(self) -> None:
|
||||||
|
# The whole point: existing cache_control on inner blocks is
|
||||||
|
# discarded so LiteLLM's ``cache_control_injection_points`` can
|
||||||
|
# re-attach exactly one breakpoint after flattening.
|
||||||
|
blocks = [
|
||||||
|
{"type": "text", "text": "first"},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "second",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
flattened = _flatten_text_blocks(blocks)
|
||||||
|
assert flattened == "first\n\nsecond"
|
||||||
|
assert "cache_control" not in flattened # type: ignore[operator]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _flattened_request — decides when to override and when to no-op.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(system_message: SystemMessage | None) -> Any:
|
||||||
|
"""Build a minimal ModelRequest stub. We only need .system_message
|
||||||
|
and .override(system_message=...) — the middleware never touches
|
||||||
|
other fields.
|
||||||
|
"""
|
||||||
|
request = MagicMock()
|
||||||
|
request.system_message = system_message
|
||||||
|
|
||||||
|
def override(**kwargs: Any) -> Any:
|
||||||
|
new_request = MagicMock()
|
||||||
|
new_request.system_message = kwargs.get(
|
||||||
|
"system_message", request.system_message
|
||||||
|
)
|
||||||
|
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
|
||||||
|
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
|
||||||
|
return new_request
|
||||||
|
|
||||||
|
request.override = override
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlattenedRequest:
|
||||||
|
def test_collapses_multi_block_system_to_string(self) -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "<base>"},
|
||||||
|
{"type": "text", "text": "<todo>"},
|
||||||
|
{"type": "text", "text": "<filesystem>"},
|
||||||
|
{"type": "text", "text": "<skills>"},
|
||||||
|
{"type": "text", "text": "<subagents>"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
|
||||||
|
assert flattened is not None
|
||||||
|
assert isinstance(flattened.system_message, SystemMessage)
|
||||||
|
assert flattened.system_message.content == (
|
||||||
|
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_op_for_string_content(self) -> None:
|
||||||
|
sys = SystemMessage(content="already a string")
|
||||||
|
request = _make_request(sys)
|
||||||
|
assert _flattened_request(request) is None
|
||||||
|
|
||||||
|
def test_no_op_for_single_block_list(self) -> None:
|
||||||
|
# One block already produces one breakpoint — no need to flatten.
|
||||||
|
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
|
||||||
|
request = _make_request(sys)
|
||||||
|
assert _flattened_request(request) is None
|
||||||
|
|
||||||
|
def test_no_op_when_system_message_missing(self) -> None:
|
||||||
|
request = _make_request(None)
|
||||||
|
assert _flattened_request(request) is None
|
||||||
|
|
||||||
|
def test_no_op_when_list_contains_non_text_block(self) -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "look"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:..."}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
assert _flattened_request(request) is None
|
||||||
|
|
||||||
|
def test_preserves_additional_kwargs_and_metadata(self) -> None:
|
||||||
|
# Defensive: nothing in the current chain sets these on a system
|
||||||
|
# message, but losing them silently when something does in the
|
||||||
|
# future would be a regression. ``name`` in particular is the only
|
||||||
|
# ``additional_kwargs`` field that ChatLiteLLM's
|
||||||
|
# ``_convert_message_to_dict`` propagates onto the wire.
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "a"},
|
||||||
|
{"type": "text", "text": "b"},
|
||||||
|
],
|
||||||
|
additional_kwargs={"name": "surfsense_system", "x": 1},
|
||||||
|
response_metadata={"tokens": 42},
|
||||||
|
)
|
||||||
|
sys.id = "sys-msg-1"
|
||||||
|
request = _make_request(sys)
|
||||||
|
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
assert flattened is not None
|
||||||
|
assert flattened.system_message.content == "a\n\nb"
|
||||||
|
assert flattened.system_message.additional_kwargs == {
|
||||||
|
"name": "surfsense_system",
|
||||||
|
"x": 1,
|
||||||
|
}
|
||||||
|
assert flattened.system_message.response_metadata == {"tokens": 42}
|
||||||
|
assert flattened.system_message.id == "sys-msg-1"
|
||||||
|
|
||||||
|
def test_idempotent_when_run_twice(self) -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "a"},
|
||||||
|
{"type": "text", "text": "b"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
first = _flattened_request(request)
|
||||||
|
assert first is not None
|
||||||
|
|
||||||
|
# Second pass on the already-flattened request should be a no-op.
|
||||||
|
# We re-wrap in a request stub since the helper inspects
|
||||||
|
# ``request.system_message.content``.
|
||||||
|
second_request = _make_request(first.system_message)
|
||||||
|
assert _flattened_request(second_request) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Middleware integration — verify the handler sees a flattened request.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareWrap:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_passes_flattened_request_to_handler(self) -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "alpha"},
|
||||||
|
{"type": "text", "text": "beta"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def handler(req: Any) -> str:
|
||||||
|
captured["request"] = req
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
mw = FlattenSystemMessageMiddleware()
|
||||||
|
result = await mw.awrap_model_call(request, handler)
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
assert isinstance(captured["request"].system_message, SystemMessage)
|
||||||
|
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_passes_through_when_already_string(self) -> None:
|
||||||
|
sys = SystemMessage(content="just a string")
|
||||||
|
request = _make_request(sys)
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def handler(req: Any) -> str:
|
||||||
|
captured["request"] = req
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
mw = FlattenSystemMessageMiddleware()
|
||||||
|
await mw.awrap_model_call(request, handler)
|
||||||
|
|
||||||
|
# Same request object: no override happened.
|
||||||
|
assert captured["request"] is request
|
||||||
|
|
||||||
|
def test_sync_passes_flattened_request_to_handler(self) -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "alpha"},
|
||||||
|
{"type": "text", "text": "beta"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(req: Any) -> str:
|
||||||
|
captured["request"] = req
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
mw = FlattenSystemMessageMiddleware()
|
||||||
|
result = mw.wrap_model_call(request, handler)
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
||||||
|
|
||||||
|
def test_sync_passes_through_when_no_system_message(self) -> None:
|
||||||
|
request = _make_request(None)
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(req: Any) -> str:
|
||||||
|
captured["request"] = req
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
mw = FlattenSystemMessageMiddleware()
|
||||||
|
mw.wrap_model_call(request, handler)
|
||||||
|
assert captured["request"] is request
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression guard — pin the worst-case shape that triggered the
|
||||||
|
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
|
||||||
|
# downstream cache_control_injection_points can only place 1 breakpoint
|
||||||
|
# on the system message regardless of provider redistribution quirks.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_regression_five_block_system_collapses_to_one_block() -> None:
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
|
||||||
|
{"type": "text", "text": "<TodoListMiddleware section>"},
|
||||||
|
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
|
||||||
|
{"type": "text", "text": "<SkillsMiddleware section>"},
|
||||||
|
{"type": "text", "text": "<SubAgentMiddleware section>"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
|
||||||
|
assert flattened is not None
|
||||||
|
assert isinstance(flattened.system_message.content, str)
|
||||||
|
# The exact join doesn't matter for the cache_control accounting —
|
||||||
|
# only that there is exactly ONE content block when LiteLLM's
|
||||||
|
# AnthropicCacheControlHook later targets ``role: system``.
|
||||||
|
assert "<surfsense base" in flattened.system_message.content
|
||||||
|
assert "<SubAgentMiddleware" in flattened.system_message.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_regression_human_message_not_modified() -> None:
|
||||||
|
# Sanity: the middleware MUST NOT touch user messages — only the
|
||||||
|
# system message. Multi-block user content is the path that carries
|
||||||
|
# image attachments and would lose its image_url block on
|
||||||
|
# accidental flatten.
|
||||||
|
sys = SystemMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "a"},
|
||||||
|
{"type": "text", "text": "b"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
user = HumanMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "look at this"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
request = _make_request(sys)
|
||||||
|
request.messages = [user]
|
||||||
|
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
assert flattened is not None
|
||||||
|
# System flattened to string …
|
||||||
|
assert isinstance(flattened.system_message.content, str)
|
||||||
|
# … user message is untouched (the helper does not even look at it).
|
||||||
|
assert flattened.messages == [user]
|
||||||
|
assert isinstance(user.content, list)
|
||||||
|
assert len(user.content) == 2
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue