Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction

This commit is contained in:
Anish Sarkar 2026-05-04 12:03:44 +05:30
commit b981b51ab1
176 changed files with 20407 additions and 6258 deletions

View file

@ -1 +1 @@
0.0.19
0.0.20

View file

@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
# STRIPE_RECONCILIATION_BATCH_SIZE=100
# Premium token purchases ($1 per 1M tokens for premium-tier models)
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
# credit; premium turns debit the actual per-call provider cost
# reported by LiteLLM, so cheap and expensive models bill proportionally)
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
# STRIPE_TOKENS_PER_UNIT=1000000
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
# ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text)
@ -305,6 +308,24 @@ STT_SERVICE=local/base
# Advanced (optional)
# ------------------------------------------------------------------------------
# New-chat agent feature flags
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_BUSY_MUTEX=true
SURFSENSE_ENABLE_SKILLS=true
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
SURFSENSE_ENABLE_ACTION_LOG=true
SURFSENSE_ENABLE_REVERT_ROUTE=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
# Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m
@ -315,9 +336,24 @@ STT_SERVICE=local/base
# Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500
# Premium token quota per registered user (default: 5M)
# Only applies to models with billing_tier=premium in global_llm_config.yaml
# PREMIUM_TOKEN_LIMIT=5000000
# Premium credit quota per registered user, in micro-USD (default: $5).
# Premium turns are debited at the actual per-call provider cost reported
# by LiteLLM. Only applies to models with billing_tier=premium.
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
# QUOTA_MAX_RESERVE_MICROS=1000000
# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
# Per-podcast reservation for the podcast Celery task ($0.20 default).
# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
# Per-video-presentation reservation for the video Celery task ($1.00 default).
# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — public users can chat without an account
# Set TRUE to enable /free pages and anonymous chat API

View file

@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
# Set FALSE to disable new checkout session creation temporarily
STRIPE_PAGE_BUYING_ENABLED=TRUE
# Premium token purchases via Stripe (for premium-tier model usage)
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
# Premium credit purchases via Stripe (for premium-tier model usage).
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
# per-call provider cost reported by LiteLLM.
STRIPE_TOKEN_BUYING_ENABLED=FALSE
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
STRIPE_TOKENS_PER_UNIT=1000000
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
# STRIPE_TOKENS_PER_UNIT=1000000
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
PAGES_LIMIT=500
# Premium token quota per registered user (default: 3,000,000)
# Applies only to models with billing_tier=premium in global_llm_config.yaml
PREMIUM_TOKEN_LIMIT=3000000
# Premium credit quota per registered user, in micro-USD
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
# models bill proportionally. Applies only to models with
# billing_tier=premium in global_llm_config.yaml.
PREMIUM_CREDIT_MICROS_LIMIT=5000000
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
# PREMIUM_TOKEN_LIMIT=5000000
# Safety ceiling on per-call premium reservation, in micro-USD.
# stream_new_chat estimates an upper-bound cost from the model's
# litellm-published per-token rates × the config's quota_reserve_tokens
# and clamps to this value so a misconfigured model can't lock the
# user's whole balance on one call. Default $1.00.
QUOTA_MAX_RESERVE_MICROS=1000000
# Per-image reservation (in micro-USD) for the POST /image-generations
# endpoint. Bypassed for free configs. Default $0.05.
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
# Single envelope covers one transcript-generation LLM call. Default $0.20.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
# Per-video-presentation reservation (in micro-USD) used by the video
# presentation Celery task. Covers worst-case fan-out of N slide-scene
# generations + refines. Default $1.00. NOTE: tasks using the override
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — allows public users to chat without an account
# Set TRUE to enable /free pages and anonymous chat API
@ -294,3 +324,30 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
# -----------------------------------------------------------------------------
# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
# -----------------------------------------------------------------------------
# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
# on a cold turn) is reused across subsequent turns on the same thread,
# collapsing it to a microsecond hash lookup. All connector tools acquire
# their own short-lived DB session per call (Phase 2 refactor) so a cached
# closure is safe to share across requests. Flip OFF only as a last-resort
# rollback if you suspect cache-related staleness.
# SURFSENSE_ENABLE_AGENT_CACHE=true
# Cache capacity (max number of compiled-agent entries kept in memory)
# and TTL per entry (seconds). Working set is typically one entry per
# active thread on this replica; tune up for very large deployments.
# SURFSENSE_AGENT_CACHE_MAXSIZE=256
# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
# -----------------------------------------------------------------------------
# Connector discovery TTL cache (Phase 1.4 perf optimization)
# -----------------------------------------------------------------------------
# Caches the per-search-space "available connectors" + "available document
# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
# ORM event listeners auto-invalidate on connector / document inserts,
# updates and deletes — the TTL only bounds staleness for bulk-import
# paths that bypass the ORM. Set to 0 to disable the cache.
# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30

View file

@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
COPY pyproject.toml .
COPY uv.lock .
# Install PyTorch based on architecture
RUN if [ "$(uname -m)" = "x86_64" ]; then \
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
else \
pip install --no-cache-dir torch torchvision torchaudio; \
fi
# Install python dependencies
# Install all Python dependencies from uv.lock for deterministic builds.
#
# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
# which lets prod silently drift to newer upstream versions on every rebuild
# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
# Exporting the lock to requirements.txt and feeding it to `uv pip install`
# pins every transitive package to the exact version captured in uv.lock.
#
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
# downloads that the lock-based install immediately replaced. If a specific
# CUDA version is needed (driver compatibility, etc.), wire it through
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
RUN pip install --no-cache-dir uv && \
uv pip install --system --no-cache-dir -e .
uv export --frozen --no-dev --no-hashes --no-emit-project \
--format requirements-txt -o /tmp/requirements.txt && \
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
rm /tmp/requirements.txt
# Set SSL environment variables dynamically
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
# Pre-download Docling models
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
# Install Playwright browsers for web scraping if needed
RUN pip install playwright && \
playwright install chromium --with-deps
# Install Playwright browsers for web scraping (the playwright package itself
# is already installed via uv.lock above)
RUN playwright install chromium --with-deps
# Copy source code
COPY . .
# Install the project itself in editable mode. Dependencies were already
# installed deterministically from uv.lock above, so --no-deps prevents any
# re-resolution that could pull newer versions.
RUN uv pip install --system --no-cache-dir --no-deps -e .
# Copy and set permissions for entrypoint script
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh

View file

@ -0,0 +1,291 @@
"""rename premium token columns to credit micros and add cost_micros to token_usage
Migrates the premium quota system from a flat token counter to a USD-cost
based credit system, where 1 credit = 1 micro-USD ($0.000001).
Column renames (1:1 numerical mapping the prior $1 per 1M tokens Stripe
price means every existing value is already correct in the new unit, no
data transformation needed):
user.premium_tokens_limit -> premium_credit_micros_limit
user.premium_tokens_used -> premium_credit_micros_used
user.premium_tokens_reserved -> premium_credit_micros_reserved
premium_token_purchases.tokens_granted -> credit_micros_granted
New column for cost auditing per turn:
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
The "user" table is in zero_publication's column list (added in 139), so
this migration must drop and recreate the publication with the renamed
column names, otherwise zero-cache will replicate stale column names and
the FE Zero schema will fail to bind.
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
column-not-found errors from a stale catalog snapshot.
Revision ID: 140
Revises: 139
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "140"
down_revision: str | None = "139"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Replicates 139's document column list verbatim — must stay in sync.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Same five live-meter fields as 139, with the renamed column names.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_credit_micros_limit",
"premium_credit_micros_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
conn.execute(
sa.text(
"SELECT pg_terminate_backend(l.pid) "
"FROM pg_locks l "
"JOIN pg_class c ON c.oid = l.relation "
"WHERE c.relname = :tbl "
" AND l.pid != pg_backend_pid()"
),
{"tbl": table},
)
def _has_zero_version(conn, table: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = '_0_version'"
),
{"tbl": table},
).fetchone()
is not None
)
def _column_exists(conn, table: str, column: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = :col"
),
{"tbl": table, "col": column},
).fetchone()
is not None
)
def _build_publication_ddl(
user_cols: list[str],
*,
documents_has_zero_ver: bool,
user_has_zero_ver: bool,
) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
user_col_list_with_meta = user_cols + (
['"_0_version"'] if user_has_zero_ver else []
)
doc_col_list = ", ".join(doc_cols)
user_col_list = ", ".join(user_col_list_with_meta)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state, "
f'"user" ({user_col_list})'
)
def upgrade() -> None:
conn = op.get_bind()
# ------------------------------------------------------------------
# 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
# dev environments are safe.
# ------------------------------------------------------------------
if not _column_exists(conn, "token_usage", "cost_micros"):
op.add_column(
"token_usage",
sa.Column(
"cost_micros",
sa.BigInteger(),
nullable=False,
server_default="0",
),
)
# ------------------------------------------------------------------
# 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
# ------------------------------------------------------------------
if _column_exists(
conn, "premium_token_purchases", "tokens_granted"
) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
op.alter_column(
"premium_token_purchases",
"tokens_granted",
new_column_name="credit_micros_granted",
)
# ------------------------------------------------------------------
# 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
#
# We must drop the publication first (it references the old column
# names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
# in a transaction block; alembic's outer transaction already holds
# one, but a SAVEPOINT keeps the LOCK + DDL atomic.
# ------------------------------------------------------------------
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
# Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
# publications require at least the PK to be in the column list,
# which is true for both the old and new shape.
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
# Drop the publication BEFORE renaming columns, otherwise Postgres
# rejects the rename: "cannot drop column ... referenced by
# publication".
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for old, new in (
("premium_tokens_limit", "premium_credit_micros_limit"),
("premium_tokens_used", "premium_credit_micros_used"),
("premium_tokens_reserved", "premium_credit_micros_reserved"),
):
if _column_exists(conn, "user", old) and not _column_exists(
conn, "user", new
):
op.alter_column("user", old, new_column_name=new)
# Update the server_default on the renamed limit column so newly
# inserted users get $5 of credit (== 5_000_000 micros) by
# default. Existing rows are unaffected.
op.alter_column(
"user",
"premium_credit_micros_limit",
server_default="5000000",
)
# Recreate the publication with the new column names.
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
USER_COLS,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
def downgrade() -> None:
"""Revert the rename and drop ``cost_micros``.
Mirrors ``upgrade``: drop the publication, rename columns back, drop
the new column, recreate the publication with the old column list.
Same zero-cache stop/reset runbook applies in reverse.
"""
conn = op.get_bind()
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for new, old in (
("premium_credit_micros_limit", "premium_tokens_limit"),
("premium_credit_micros_used", "premium_tokens_used"),
("premium_credit_micros_reserved", "premium_tokens_reserved"),
):
if _column_exists(conn, "user", new) and not _column_exists(
conn, "user", old
):
op.alter_column("user", new, new_column_name=old)
op.alter_column(
"user",
"premium_tokens_limit",
server_default="5000000",
)
legacy_user_cols = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
legacy_user_cols,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
if _column_exists(
conn, "premium_token_purchases", "credit_micros_granted"
) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
op.alter_column(
"premium_token_purchases",
"credit_micros_granted",
new_column_name="tokens_granted",
)
if _column_exists(conn, "token_usage", "cost_micros"):
op.drop_column("token_usage", "cost_micros")

View file

@ -0,0 +1,357 @@
"""TTL-LRU cache for compiled SurfSense deep agents.
Why this exists
---------------
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
turn:
1. Discover connectors & document types from Postgres (~50-200ms)
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
3. Compose the system prompt
4. Construct ~15 middleware instances (CPU)
5. Eagerly compile the general-purpose subagent
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
which builds a second LangGraph + Pydantic schemas ~1.5-2s of pure
CPU work)
6. Compile the outer LangGraph
For a single thread, all six steps produce the SAME object on every turn
unless the user has changed their LLM config, toggled a feature flag,
added a connector, etc. The right answer is to compile ONCE per
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
every subsequent turn on the same thread.
Why a per-thread key (not a global pool)
----------------------------------------
Most middleware in the SurfSense stack captures per-thread state in
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
would silently leak state across users and threads. Keying the cache on
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
turns on the same thread without changing any middleware's behavior.
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
(read via ``runtime.context``) so the cache can collapse to a single
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
then, per-thread keying is the only safe option.
Cache shape
-----------
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
minutes matches a typical chat session). ``maxsize`` (default 256)
caps memory; LRU evicts least-recently-used on overflow.
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
cold misses on the same key wait for the first build instead of
building N times.
* Process-local: this is an in-memory cache. Multi-replica deployments
pay the build cost once per replica per key. That's fine; the working
set per replica is small (one entry per active thread on that replica).
Telemetry
---------
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
* ``hit`` cache hit, microseconds-fast
* ``miss`` first build for this key, includes build duration
* ``stale`` entry was found but expired; rebuilt
* ``evict`` LRU eviction (size-limited)
* ``size`` current cache occupancy at lookup time
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# ---------------------------------------------------------------------------
# Public API: signature helpers (cache key components)
# ---------------------------------------------------------------------------
def stable_hash(*parts: Any) -> str:
"""Compute a deterministic SHA1 of the str repr of ``parts``.
Used for cache key components that need a fixed-width representation
(system prompt, tool list, etc.). SHA1 is fine here this is not a
security boundary, just a content fingerprint.
"""
h = hashlib.sha1(usedforsecurity=False)
for p in parts:
h.update(repr(p).encode("utf-8", errors="replace"))
h.update(b"\x1f") # ASCII unit separator between parts
return h.hexdigest()
def tools_signature(
tools: list[Any] | tuple[Any, ...],
*,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> str:
"""Hash the bound-tool surface for cache-key purposes.
The signature changes whenever:
* A tool is added or removed from the bound list (built-in toggles,
MCP tools loaded for the user changes, gating rules flip, etc.).
* The available connectors / document types for the search space
change (new connector added, last connector removed, new document
type indexed). Because :func:`get_connector_gated_tools` derives
``modified_disabled_tools`` from ``available_connectors``, the
tool surface is technically already covered but we hash the
connector list separately so an empty-list "no tools changed"
situation still rotates the key when, say, the user re-adds a
connector that gates a tool we were already not exposing.
Stays stable across:
* Process restarts (tool names + descriptions are static).
* Different replicas (everyone gets the same hash for the same
inputs).
"""
tool_descriptors = sorted(
(getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
)
connectors = sorted(available_connectors or [])
doc_types = sorted(available_document_types or [])
return stable_hash(tool_descriptors, connectors, doc_types)
def flags_signature(flags: Any) -> str:
"""Hash the resolved :class:`AgentFeatureFlags` dataclass.
Frozen dataclasses are deterministically reprable, so a SHA1 of their
repr is a stable fingerprint. Restart safe (flags are read once at
process boot).
"""
return stable_hash(repr(flags))
def system_prompt_hash(system_prompt: str) -> str:
"""Hash a system prompt string. Cheap, ~30µs for typical prompts."""
return hashlib.sha1(
system_prompt.encode("utf-8", errors="replace"),
usedforsecurity=False,
).hexdigest()
# ---------------------------------------------------------------------------
# Cache implementation
# ---------------------------------------------------------------------------
@dataclass
class _Entry:
value: Any
created_at: float
last_used_at: float
class _AgentCache:
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
NOT THREAD-SAFE in the multithreading sense designed for a single
asyncio event loop. Uvicorn runs one event loop per worker process,
so this is fine; multi-worker deployments simply each maintain their
own cache.
"""
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
self._maxsize = maxsize
self._ttl = ttl_seconds
self._entries: OrderedDict[str, _Entry] = OrderedDict()
# One lock per key — guards "build" so concurrent cold misses on
# the same key wait for the first build instead of all racing.
self._locks: dict[str, asyncio.Lock] = {}
def _now(self) -> float:
return time.monotonic()
def _is_fresh(self, entry: _Entry) -> bool:
return (self._now() - entry.created_at) < self._ttl
def _evict_if_full(self) -> None:
while len(self._entries) >= self._maxsize:
evicted_key, _ = self._entries.popitem(last=False)
self._locks.pop(evicted_key, None)
_perf_log.info(
"[agent_cache] evict key=%s reason=lru size=%d",
_short(evicted_key),
len(self._entries),
)
def _touch(self, key: str, entry: _Entry) -> None:
entry.last_used_at = self._now()
self._entries.move_to_end(key, last=True)
async def get_or_build(
self,
key: str,
*,
builder: Callable[[], Awaitable[Any]],
) -> Any:
"""Return the cached value for ``key`` or call ``builder()`` to make it.
``builder`` MUST be idempotent concurrent cold misses on the
same key collapse to a single ``builder()`` call (the others
wait on the in-flight lock and observe the populated entry on
wake).
"""
# Fast path: hot hit.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
# Stale entry — drop it; rebuild below.
if entry is not None and not self._is_fresh(entry):
_perf_log.info(
"[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
_short(key),
self._now() - entry.created_at,
self._ttl,
)
self._entries.pop(key, None)
# Slow path: serialize concurrent misses for the same key.
lock = self._locks.setdefault(key, asyncio.Lock())
async with lock:
# Double-check after acquiring the lock — another waiter may
# have populated the entry while we slept.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
t0 = time.perf_counter()
try:
value = await builder()
except BaseException:
# Don't cache failed builds; let the next caller retry.
_perf_log.warning(
"[agent_cache] build_failed key=%s elapsed=%.3fs",
_short(key),
time.perf_counter() - t0,
)
raise
elapsed = time.perf_counter() - t0
# Insert + evict.
self._evict_if_full()
now = self._now()
self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
self._entries.move_to_end(key, last=True)
_perf_log.info(
"[agent_cache] miss key=%s build=%.3fs size=%d",
_short(key),
elapsed,
len(self._entries),
)
return value
def invalidate(self, key: str) -> bool:
"""Drop a single entry; return True if anything was removed."""
removed = self._entries.pop(key, None) is not None
self._locks.pop(key, None)
if removed:
_perf_log.info(
"[agent_cache] invalidate key=%s size=%d",
_short(key),
len(self._entries),
)
return removed
def invalidate_prefix(self, prefix: str) -> int:
"""Drop every entry whose key starts with ``prefix``. Returns count."""
keys = [k for k in self._entries if k.startswith(prefix)]
for k in keys:
self._entries.pop(k, None)
self._locks.pop(k, None)
if keys:
_perf_log.info(
"[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
_short(prefix),
len(keys),
len(self._entries),
)
return len(keys)
def clear(self) -> None:
n = len(self._entries)
self._entries.clear()
self._locks.clear()
if n:
_perf_log.info("[agent_cache] clear removed=%d", n)
def stats(self) -> dict[str, Any]:
return {
"size": len(self._entries),
"maxsize": self._maxsize,
"ttl_seconds": self._ttl,
}
def _short(key: str, n: int = 16) -> str:
"""Truncate keys for log lines so they don't blow up log volume."""
return key if len(key) <= n else f"{key[:n]}..."
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
def get_cache() -> _AgentCache:
"""Return the process-wide compiled-agent cache singleton."""
return _cache
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
"""Replace the singleton with a fresh cache. Tests only."""
global _cache
_cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
return _cache
__all__ = [
"flags_signature",
"get_cache",
"reload_for_tests",
"stable_hash",
"system_prompt_hash",
"tools_signature",
]

View file

@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
FlattenSystemMessageMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
@ -330,23 +338,39 @@ async def create_surfsense_deep_agent(
else None,
)
# Discover available connectors and document types for this search space
# Discover available connectors and document types for this search space.
#
# NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
# ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
# SQLAlchemy explicitly forbids concurrent operations on the same session
# ("This session is provisioning a new connection; concurrent operations
# are not permitted on the same session"). The Phase 1.4 in-process TTL
# cache in ``connector_service`` already collapses the warm path to a
# near-zero pair of dict lookups, so sequential awaits cost nothing in
# the common case while remaining correct on cold cache misses.
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
connector_types = await connector_service.get_available_connectors(
search_space_id
)
if connector_types:
available_connectors = _map_connectors_to_searchable_types(connector_types)
try:
connector_types_result = await connector_service.get_available_connectors(
search_space_id
)
if connector_types_result:
available_connectors = _map_connectors_to_searchable_types(
connector_types_result
)
except Exception as e:
logging.warning("Failed to discover available connectors: %s", e)
available_document_types = await connector_service.get_available_document_types(
search_space_id
)
except Exception as e:
try:
available_document_types = (
await connector_service.get_available_document_types(search_space_id)
)
except Exception as e:
logging.warning("Failed to discover available document types: %s", e)
except Exception as e: # pragma: no cover - defensive outer guard
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
@ -469,29 +493,77 @@ async def create_surfsense_deep_agent(
# entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive.
#
# PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
# on every per-request value that any middleware in the stack closes
# over in ``__init__`` — drop one and you risk leaking state across
# threads. Hits collapse this whole block to a microsecond lookup;
# misses pay the original CPU cost AND populate the cache.
config_id = agent_config.config_id if agent_config is not None else None
async def _build_agent() -> Any:
return await asyncio.to_thread(
_build_compiled_agent_blocking,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
# ``mentioned_document_ids`` is consumed by
# ``KnowledgePriorityMiddleware`` per turn via
# ``runtime.context`` (Phase 1.5). We still pass the
# caller-provided list here for the legacy fallback path
# (cache disabled / context not propagated) — the middleware
# drains its own copy after the first read so a cached graph
# never replays stale mentions.
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=_max_input_tokens,
flags=_flags,
checkpointer=checkpointer,
)
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
_build_compiled_agent_blocking,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=_max_input_tokens,
flags=_flags,
checkpointer=checkpointer,
)
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
# Cache key components — order matters only for human readability;
# the resulting hash is what's stored. Every component must
# rotate on a real shape change AND stay stable across identical
# invocations.
cache_key = stable_hash(
"v1", # schema version of the key — bump if components change
config_id,
thread_id,
user_id,
search_space_id,
visibility,
filesystem_selection.mode,
anon_session_id,
tools_signature(
tools,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
flags_signature(_flags),
system_prompt_hash(final_system_prompt),
_max_input_tokens,
# ``mentioned_document_ids`` deliberately omitted — middleware
# reads it from ``runtime.context`` (Phase 1.5).
)
agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
else:
agent = await _build_agent()
_perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs",
"[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
time.perf_counter() - _t0,
"on"
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
else "off",
)
_perf_log.info(
@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking(
noop_mw,
retry_mw,
fallback_mw,
# Coalesce a multi-text-block system message into one block
# immediately before the model call. Sits innermost on the
# system-message-mutation chain so it observes every appender
# (todo / filesystem / skills / subagents …) and prevents
# OpenRouter→Anthropic from redistributing ``cache_control``
# across N blocks and tripping Anthropic's 4-breakpoint cap.
# See ``middleware/flatten_system.py`` for full rationale.
FlattenSystemMessageMiddleware(),
# Tool-call repair must run after model emits but before
# permission / dedup / doom-loop interpret the calls.
repair_mw,

View file

@ -1,10 +1,25 @@
"""
Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent.
This module defines the per-invocation context object passed to the SurfSense
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
The agent's compiled graph is the same across invocations (and cached by
``agent_cache``), so anything that varies per turn the user mentions a
specific document, the front-end issues a unique ``request_id``, etc.
MUST live on this context object instead of being captured into a
middleware ``__init__`` closure. Middlewares read fields back via
``runtime.context.<field>``; tools read them via ``runtime.context``.
This object is read inside both ``KnowledgePriorityMiddleware`` (for
``mentioned_document_ids``) and any future middleware that needs
per-request state without invalidating the compiled-agent cache.
"""
from typing import NotRequired, TypedDict
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
class FileOperationContractState(TypedDict):
@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
turn_id: str
class SurfSenseContextSchema(TypedDict):
@dataclass
class SurfSenseContextSchema:
"""
Custom state schema for the SurfSense deep agent.
Per-invocation context for the SurfSense deep agent.
This extends the default agent state with custom fields.
The default state already includes:
- messages: Conversation history
- todos: Task list from TodoListMiddleware
- files: Virtual filesystem from FilesystemMiddleware
Defaults are chosen so the dataclass can be safely default-constructed
(LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
context is supplied see ``langgraph.runtime.Runtime``). All fields
are optional; consumers must None-check before reading.
We're adding fields needed for knowledge base search:
- search_space_id: The user's search space ID
- db_session: Database session (injected at runtime)
- connector_service: Connector service instance (injected at runtime)
Phase 1.5 fields:
search_space_id: Search space the request is scoped to.
mentioned_document_ids: KB documents the user @-mentioned this turn.
Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key that's the
whole point of putting it here.
file_operation_contract: One-shot file operation contract emitted
by ``FileIntentMiddleware`` for the upcoming turn.
turn_id / request_id: Correlation IDs surfaced by the streaming
task; populated for telemetry.
Phase 2 will extend with: thread_id, user_id, visibility,
filesystem_mode, anon_session_id, available_connectors,
available_document_types, created_by_id (everything currently captured
by middleware ``__init__`` closures).
"""
search_space_id: int
file_operation_contract: NotRequired[FileOperationContractState]
turn_id: NotRequired[str]
request_id: NotRequired[str]
# These are runtime-injected and won't be serialized
# db_session and connector_service are passed when invoking the agent
search_space_id: int | None = None
mentioned_document_ids: list[int] = field(default_factory=list)
file_operation_contract: FileOperationContractState | None = None
turn_id: str | None = None
request_id: str | None = None

View file

@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
SurfSense-native). They follow a "default-OFF for risky things,
default-ON for safe upgrades, master kill-switch for everything new" model.
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
image updates work even when older installs do not have newly introduced
environment variables. Risky/experimental integrations stay default OFF,
and the master kill-switch can still disable everything new.
All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
Examples
--------
Local development (recommended for trying everything except doom-loop / selector):
Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else):
@ -60,32 +65,28 @@ class AgentFeatureFlags:
disable_new_agent_stack: bool = False
# Agent quality — context budget, retry/limits, name-repair, doom-loop
enable_context_editing: bool = False
enable_compaction_v2: bool = False
enable_retry_after: bool = False
enable_context_editing: bool = True
enable_compaction_v2: bool = True
enable_retry_after: bool = True
enable_model_fallback: bool = False
enable_model_call_limit: bool = False
enable_tool_call_limit: bool = False
enable_tool_call_repair: bool = False
enable_doom_loop: bool = (
False # Default OFF until UI handles permission='doom_loop'
)
enable_model_call_limit: bool = True
enable_tool_call_limit: bool = True
enable_tool_call_repair: bool = True
enable_doom_loop: bool = True
# Safety — permissions, concurrency, tool-set narrowing
enable_permission: bool = False # Default OFF for first deploy
enable_busy_mutex: bool = False
enable_permission: bool = True
enable_busy_mutex: bool = True
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
# Skills + subagents
enable_skills: bool = False
enable_specialized_subagents: bool = False
enable_kb_planner_runnable: bool = False
enable_skills: bool = True
enable_specialized_subagents: bool = True
enable_kb_planner_runnable: bool = True
# Snapshot / revert
enable_action_log: bool = False
enable_revert_route: bool = (
False # Backend ships before UI; route returns 503 until this flips
)
enable_action_log: bool = True
enable_revert_route: bool = True
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@ -94,7 +95,7 @@ class AgentFeatureFlags:
# text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = False
enable_stream_parity_v2: bool = True
# Plugins
enable_plugin_loader: bool = False
@ -102,6 +103,41 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False
# Performance — compiled-agent cache (Phase 1 + Phase 2).
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
# graph if the cache key matches (LLM config + thread + tool surface +
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
# wall clock from ~4-5s to <50µs on cache hits.
#
# SAFETY (Phase 2 unblocked this default-on):
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
# ``update_memory``, ``search_surfsense_docs``) now acquire fresh
# short-lived ``AsyncSession`` instances per call via
# :data:`async_session_maker`. The factory still accepts ``db_session``
# for registry compatibility but ``del``'s it immediately — see any
# of those files' factory docstrings for the rationale. The ``llm``
# closure is per-(provider, model, config_id) which is already in
# the cache key, so the LLM is safe to share across cached hits of
# the same key. The KB priority middleware reads
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
# not its constructor closure, so the same compiled agent serves
# turns with different mention lists correctly.
#
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
# environment if a regression surfaces. The path is exercised by
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
enable_agent_cache: bool = True
# Phase 1 (deferred — measure first): pre-build & share the
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
# misses. Only helps when the outer cache MISSES (cache hits already
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
# until we have data showing cold misses are frequent enough to
# justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False
@classmethod
def from_env(cls) -> AgentFeatureFlags:
"""Read flags from environment.
@ -115,48 +151,76 @@ class AgentFeatureFlags:
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
"middleware is forced OFF for this build."
)
return cls(disable_new_agent_stack=True)
return cls(
disable_new_agent_stack=True,
enable_context_editing=False,
enable_compaction_v2=False,
enable_retry_after=False,
enable_model_fallback=False,
enable_model_call_limit=False,
enable_tool_call_limit=False,
enable_tool_call_repair=False,
enable_doom_loop=False,
enable_permission=False,
enable_busy_mutex=False,
enable_llm_tool_selector=False,
enable_skills=False,
enable_specialized_subagents=False,
enable_kb_planner_runnable=False,
enable_action_log=False,
enable_revert_route=False,
enable_stream_parity_v2=False,
enable_plugin_loader=False,
enable_otel=False,
enable_agent_cache=False,
enable_agent_cache_share_gp_subagent=False,
)
return cls(
disable_new_agent_stack=False,
# Agent quality
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool(
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
enable_tool_call_repair=_env_bool(
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
# Safety
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
),
# Skills + subagents
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
enable_specialized_subagents=_env_bool(
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
),
enable_kb_planner_runnable=_env_bool(
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
),
# Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
# Performance
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
enable_agent_cache_share_gp_subagent=_env_bool(
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
),
)
def any_new_middleware_enabled(self) -> bool:

View file

@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
# Provider mapping for LiteLLM model string construction.
#
# Single source of truth lives in
# :mod:`app.services.provider_capabilities` so the YAML loader (which
# runs during ``app.config`` class-body init) can resolve provider
# prefixes without dragging the agent / tools tree into module load
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
# tests) keep working unchanged.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@ -178,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog.
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
# which prefers OpenRouter's ``architecture.input_modalities`` and
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
supports_image_input: bool = True
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
@ -203,6 +191,12 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually
# contains at least one vision-capable deployment; the router
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
supports_image_input=True,
)
@classmethod
@ -216,10 +210,24 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
return cls(
provider=config.provider.value
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider),
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
return cls(
provider=provider_value,
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
@ -235,6 +243,16 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we
# ask LiteLLM (default-allow on unknown). The streaming
# safety net still blocks if the model is *explicitly*
# marked text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
),
)
@classmethod
@ -253,15 +271,46 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
# Explicit YAML override wins; otherwise derive from LiteLLM /
# OpenRouter modalities. The YAML loader already populates this
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
# so we re-derive here for safety. The bool() coercion preserves
# the loader's behaviour for explicit ``true`` / ``false``
# strings that PyYAML may surface.
if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
)
return cls(
provider=yaml_config.get("provider", "").upper(),
model_name=yaml_config.get("model_name", ""),
provider=provider,
model_name=model_name,
api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"),
custom_provider=yaml_config.get("custom_provider"),
custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
@ -276,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
supports_image_input=supports_image_input,
)

View file

@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
)
from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"FileIntentMiddleware",
"FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",

View file

@ -0,0 +1,233 @@
r"""Coalesce multi-block system messages into a single text block.
Several middlewares in our deepagent stack each call
``append_to_system_message`` on the way down to the model
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
``SkillsMiddleware``, ``SubAgentMiddleware`` ). By the time the
request reaches the LLM, the system message has 5+ separate text blocks.
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
request**, and we configure 2 injection points
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
the prepended ``request.system_message``, this middleware is the
defensive partner: it guarantees that "the system block" is *one*
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
OpenRouterAnthropic transformer can never multiply our budget into
several breakpoints by spreading ``cache_control`` across multiple
text blocks of a multi-block system content.
Without flattening we used to see::
OpenrouterException - {"error":{"message":"Provider returned error",
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
cache_control may be provided. Found 5."}}}
(Same error class documented in
https://github.com/BerriAI/litellm/issues/15696 and
https://github.com/BerriAI/litellm/issues/20485 the litellm-side fix
in PR #15395 covers the litellm transformer but does not protect us
when the OpenRouter SaaS itself does the redistribution.)
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
the first injection point from ``role: system`` to ``index: 0``)
neutralises the *primary* cause of the same 400 multiple
``SystemMessage``\ s injected by ``before_agent`` middlewares
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
turns, each tagged with ``cache_control`` by the ``role: system``
matcher. This middleware remains useful as defence-in-depth against
the multi-block redistribution path.
Placement: innermost on the system-message-mutation chain, after every
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
summarization, but before ``noop``/``retry``/``fallback`` so each retry
attempt sees a flattened payload. See ``chat_deepagent.py``.
Idempotent: a string-content system message is left untouched. A list
that contains anything other than plain text blocks (e.g. an image) is
also left untouched those are rare on system messages and we'd lose
the non-text payload by joining.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import SystemMessage
logger = logging.getLogger(__name__)
def _flatten_text_blocks(content: list[Any]) -> str | None:
"""Return joined text if every block is a plain ``{"type": "text"}``.
Returns ``None`` when the list contains anything that isn't a text
block we can safely concatenate (image, audio, file, non-standard
blocks, dicts with extra non-cache_control fields). The caller
leaves the original content untouched in that case rather than
silently dropping payload.
``cache_control`` on individual blocks is intentionally discarded
the whole point of flattening is to let LiteLLM's
``cache_control_injection_points`` re-place a single breakpoint on
the resulting one-block system content.
"""
chunks: list[str] = []
for block in content:
if isinstance(block, str):
chunks.append(block)
continue
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
chunks.append(text)
return "\n\n".join(chunks)
def _flattened_request(
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
"""Return a request with system_message flattened, or ``None`` for no-op."""
sys_msg = request.system_message
if sys_msg is None:
return None
content = sys_msg.content
if not isinstance(content, list) or len(content) <= 1:
return None
flattened = _flatten_text_blocks(content)
if flattened is None:
return None
new_sys = SystemMessage(
content=flattened,
additional_kwargs=dict(sys_msg.additional_kwargs),
response_metadata=dict(sys_msg.response_metadata),
)
if sys_msg.id is not None:
new_sys.id = sys_msg.id
return request.override(system_message=new_sys)
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
"""One-line dump of cache_control-relevant request shape.
Temporary diagnostic to prove where the ``Found N`` cache_control
breakpoints are coming from when Anthropic 400s. Removed once the
root cause is confirmed and a fix is in place.
"""
sys_msg = request.system_message
if sys_msg is None:
sys_shape = "none"
elif isinstance(sys_msg.content, str):
sys_shape = f"str(len={len(sys_msg.content)})"
elif isinstance(sys_msg.content, list):
sys_shape = f"list(blocks={len(sys_msg.content)})"
else:
sys_shape = f"other({type(sys_msg.content).__name__})"
role_hist: list[str] = []
multi_block_msgs = 0
msgs_with_cc = 0
sys_msgs_in_history = 0
for m in request.messages:
mtype = getattr(m, "type", type(m).__name__)
role_hist.append(mtype)
if isinstance(m, SystemMessage):
sys_msgs_in_history += 1
c = getattr(m, "content", None)
if isinstance(c, list):
multi_block_msgs += 1
for blk in c:
if isinstance(blk, dict) and "cache_control" in blk:
msgs_with_cc += 1
break
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
msgs_with_cc += 1
tools = request.tools or []
tools_with_cc = 0
for t in tools:
if isinstance(t, dict) and (
"cache_control" in t or "cache_control" in t.get("function", {})
):
tools_with_cc += 1
return (
f"sys={sys_shape} msgs={len(request.messages)} "
f"sys_msgs_in_history={sys_msgs_in_history} "
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
f"roles={role_hist[-8:]}"
)
class FlattenSystemMessageMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Collapse a multi-text-block system message to a single string.
Sits innermost on the system-message-mutation chain so it observes
every middleware's contribution. Has no other side effect — the
body of every block is preserved, just joined with ``"\\n\\n"``.
"""
def __init__(self) -> None:
super().__init__()
self.tools = []
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return handler(flattened)
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return await handler(flattened)
return await handler(request)
__all__ = [
"FlattenSystemMessageMiddleware",
"_flatten_text_blocks",
"_flattened_request",
]

View file

@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc:
return self._anon_priority(state, anon_doc)
return await self._authenticated_priority(state, messages, user_text)
return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority(
self,
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState,
messages: Sequence[BaseMessage],
user_text: str,
runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
# Per-turn ``mentioned_document_ids`` flow:
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
#
# CRITICAL: distinguish "context absent" (legacy caller, no field at
# all) from "context provided but empty" (turn with no mentions).
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions)
if self.mentioned_document_ids:
self.mentioned_document_ids = []
elif self.mentioned_document_ids:
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
mentioned_results: list[dict[str, Any]] = []
if self.mentioned_document_ids:
if mention_ids:
mentioned_results = await fetch_mentioned_documents(
document_ids=self.mentioned_document_ids,
document_ids=mention_ids,
search_space_id=self.search_space_id,
)
self.mentioned_document_ids = []
if is_recency:
doc_types = _resolve_search_types(

View file

@ -1,4 +1,4 @@
"""LiteLLM-native prompt caching configuration for SurfSense agents.
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack its ``isinstance(model, ChatAnthropic)``
@ -17,8 +17,20 @@ Coverage:
We inject **two** breakpoints per request:
- ``role: system`` pins the SurfSense system prompt (provider variant,
citation rules, tool catalog, KB tree, skills metadata) into the cache.
- ``index: 0`` pins the SurfSense system prompt at the head of the
request (provider variant, citation rules, tool catalog, KB tree,
skills metadata). The langchain agent factory always prepends
``request.system_message`` at index 0 (see ``factory.py``
``_execute_model_async``), so this targets exactly the main system
prompt regardless of how many other ``SystemMessage``\ s the
``before_agent`` injectors (priority, tree, memory, file-intent,
anonymous-doc) have inserted into ``state["messages"]``. Using
``role: system`` here would apply ``cache_control`` to **every**
system-role message and trip Anthropic's hard cap of 4 cache
breakpoints per request once the conversation accumulates enough
injected system messages which surfaces as the upstream 400
``A maximum of 4 blocks with cache_control may be provided. Found N``
via OpenRouterAnthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
@ -51,11 +63,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Two-breakpoint policy: system + latest message. See module docstring for
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
# use 2 here, leaving headroom for Phase-2 tool caching.
# Two-breakpoint policy: head-of-request + latest message. See module
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
# anonymous-doc) insert ``SystemMessage`` instances into
# ``state["messages"]`` that accumulate across turns. With
# ``role: system`` the LiteLLM hook would tag *every* one of them with
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
# always targets the langchain-prepended ``request.system_message``
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
# block), giving us exactly one stable cache breakpoint.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "role": "system"},
{"location": "message", "index": 0},
{"location": "message", "index": -1},
)

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the create_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool
async def create_confluence_page(
title: str,
@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
"""
logger.info(f"create_confluence_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
return {"status": "error", "message": context["error"]}
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Confluence accounts need re-authentication.",
"connector_type": "confluence",
}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Confluence accounts need re-authentication.",
"connector_type": "confluence",
}
result = request_approval(
action_type="confluence_page_creation",
tool_name="create_confluence_page",
params={
"title": title,
"content": content,
"space_id": space_id,
"connector_id": connector_id,
},
context=context,
)
result = request_approval(
action_type="confluence_page_creation",
tool_name="create_confluence_page",
params={
"title": title,
"content": content,
"space_id": space_id,
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_content = result.params.get("content", content) or ""
final_space_id = result.params.get("space_id", space_id)
final_connector_id = result.params.get("connector_id", connector_id)
final_title = result.params.get("title", title)
final_content = result.params.get("content", content) or ""
final_space_id = result.params.get("space_id", space_id)
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
return {"status": "error", "message": "Page title cannot be empty."}
if not final_space_id:
return {"status": "error", "message": "A space must be selected."}
if not final_title or not final_title.strip():
return {"status": "error", "message": "Page title cannot be empty."}
if not final_space_id:
return {"status": "error", "message": "A space must be selected."}
from sqlalchemy.future import select
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Confluence connector found.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
api_result = await client.create_page(
space_id=final_space_id,
title=final_title,
body=final_content,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_id = str(api_result.get("id", ""))
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=page_id,
page_title=final_title,
space_id=final_space_id,
body_content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Confluence connector found.",
}
actual_connector_id = connector.id
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
return {
"status": "success",
"page_id": page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
api_result = await client.create_page(
space_id=final_space_id,
title=final_title,
body=final_content,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_id = str(api_result.get("id", ""))
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=page_id,
page_title=final_title,
space_id=final_space_id,
body_content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"page_id": page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the delete_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_confluence_page(
page_title_or_id: str,
@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, page_title_or_id
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
page_title = page_data.get("page_title", "")
document_id = page_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_deletion",
tool_name="delete_confluence_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, page_title_or_id
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
page_title = page_data.get("page_title", "")
document_id = page_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_deletion",
tool_name="delete_confluence_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
await client.delete_page(final_page_id)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
if result.rejected:
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
raise
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
message = f"Confluence page '{page_title}' deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
await client.delete_page(final_page_id)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
return {
"status": "success",
"page_id": final_page_id,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Confluence page '{page_title}' deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"page_id": final_page_id,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the update_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool
async def update_confluence_page(
page_title_or_id: str,
@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, page_title_or_id
)
async with async_session_maker() as db_session:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, page_title_or_id
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
current_title = page_data["page_title"]
current_body = page_data.get("body", "")
current_version = page_data.get("version", 1)
document_id = page_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_update",
tool_name="update_confluence_page",
params={
"page_id": page_id,
"document_id": document_id,
"new_title": new_title,
"new_content": new_content,
"version": current_version,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
current_title = page_data["page_title"]
current_body = page_data.get("body", "")
current_version = page_data.get("version", 1)
document_id = page_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_update",
tool_name="update_confluence_page",
params={
"page_id": page_id,
"document_id": document_id,
"new_title": new_title,
"new_content": new_content,
"version": current_version,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_title = result.params.get("new_title", new_title) or current_title
final_content = result.params.get("new_content", new_content)
if final_content is None:
final_content = current_body
final_version = result.params.get("version", current_version)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
final_page_id = result.params.get("page_id", page_id)
final_title = result.params.get("new_title", new_title) or current_title
final_content = result.params.get("new_content", new_content)
if final_content is None:
final_content = current_body
final_version = result.params.get("version", current_version)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
final_document_id = result.params.get("document_id", document_id)
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
api_result = await client.update_page(
page_id=final_page_id,
title=final_title,
body=final_content,
version_number=final_version + 1,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
"status": "error",
"message": "No connector found for this page.",
}
raise
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
if final_document_id:
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
page_id=final_page_id,
user_id=user_id,
search_space_id=search_space_id,
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
api_result = await client.update_page(
page_id=final_page_id,
title=final_title,
body=final_content,
version_number=final_version + 1,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
if final_document_id:
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
page_id=final_page_id,
user_id=user_id,
search_space_id=search_space_id,
)
else:
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return {
"status": "success",
"page_id": final_page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
}
return {
"status": "success",
"page_id": final_page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
search_space_id: int,
user_id: str,
) -> StructuredTool:
"""Factory function to create the get_connected_accounts tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to scope account discovery to.
user_id: User ID to scope account discovery to.
Returns:
Configured StructuredTool for connected-accounts discovery.
"""
del db_session # per-call session — see docstring
async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service)
@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == connector_type,
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == connector_type,
)
)
)
connectors = result.scalars().all()
connectors = result.scalars().all()
if not connectors:
return [
{
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
if not connectors:
return [
{
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
}
]
is_multi = len(connectors) > 1
accounts: list[dict[str, Any]] = []
for conn in connectors:
cfg = conn.config or {}
entry: dict[str, Any] = {
"connector_id": conn.id,
"display_name": _extract_display_name(conn),
"service": service,
}
]
if is_multi:
entry["tool_prefix"] = f"{service}_{conn.id}"
for key in svc_cfg.account_metadata_keys:
if key in cfg:
entry[key] = cfg[key]
accounts.append(entry)
is_multi = len(connectors) > 1
accounts: list[dict[str, Any]] = []
for conn in connectors:
cfg = conn.config or {}
entry: dict[str, Any] = {
"connector_id": conn.id,
"display_name": _extract_display_name(conn),
"service": service,
}
if is_multi:
entry["tool_prefix"] = f"{service}_{conn.id}"
for key in svc_cfg.account_metadata_keys:
if key in cfg:
entry[key] = cfg[key]
accounts.append(entry)
return accounts
return accounts
return StructuredTool(
name="get_connected_accounts",

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the list_discord_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_discord_channels tool
"""
del db_session # per-call session — see docstring
@tool
async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server.
@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
Returns:
Dictionary with status and a list of channels (id, name).
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
}
try:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
guild_id = get_guild_id(connector)
if not guild_id:
return {
"status": "error",
"message": "No guild ID in Discord connector config.",
}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/guilds/{guild_id}/channels",
headers={"Authorization": f"Bot {token}"},
timeout=15.0,
async with async_session_maker() as db_session:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
guild_id = get_guild_id(connector)
if not guild_id:
return {
"status": "error",
"message": "No guild ID in Discord connector config.",
}
# Type 0 = text channel
channels = [
{"id": ch["id"], "name": ch["name"]}
for ch in resp.json()
if ch.get("type") == 0
]
return {
"status": "success",
"guild_id": guild_id,
"channels": channels,
"total": len(channels),
}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/guilds/{guild_id}/channels",
headers={"Authorization": f"Bot {token}"},
timeout=15.0,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
# Type 0 = text channel
channels = [
{"id": ch["id"], "name": ch["name"]}
for ch in resp.json()
if ch.get("type") == 0
]
return {
"status": "success",
"guild_id": guild_id,
"channels": channels,
"total": len(channels),
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the read_discord_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_discord_messages tool
"""
del db_session # per-call session — see docstring
@tool
async def read_discord_messages(
channel_id: str,
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
Dictionary with status and a list of messages including
id, author, content, timestamp.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
limit = min(limit, 50)
try:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/channels/{channel_id}/messages",
headers={"Authorization": f"Bot {token}"},
params={"limit": limit},
timeout=15.0,
async with async_session_maker() as db_session:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
token = get_bot_token(connector)
messages = [
{
"id": m["id"],
"author": m.get("author", {}).get("username", "Unknown"),
"content": m.get("content", ""),
"timestamp": m.get("timestamp", ""),
}
for m in resp.json()
]
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/channels/{channel_id}/messages",
headers={"Authorization": f"Bot {token}"},
params={"limit": limit},
timeout=15.0,
)
return {
"status": "success",
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
messages = [
{
"id": m["id"],
"author": m.get("author", {}).get("username", "Unknown"),
"content": m.get("content", ""),
"timestamp": m.get("timestamp", ""),
}
for m in resp.json()
]
return {
"status": "success",
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the send_discord_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_discord_message tool
"""
del db_session # per-call session — see docstring
@tool
async def send_discord_message(
channel_id: str,
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Discord tool not properly configured.",
@ -47,64 +65,65 @@ def create_send_discord_message_tool(
}
try:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
async with async_session_maker() as db_session:
connector = await get_discord_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Discord connector found."}
result = request_approval(
action_type="discord_send_message",
tool_name="send_discord_message",
params={"channel_id": channel_id, "content": content},
context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
final_content = result.params.get("content", content)
final_channel = result.params.get("channel_id", channel_id)
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{DISCORD_API}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
},
json={"content": final_content},
timeout=15.0,
result = request_approval(
action_type="discord_send_message",
tool_name="send_discord_message",
params={"channel_id": channel_id, "content": content},
context={"connector_id": connector.id},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to send messages in this channel.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": f"Message sent to channel {final_channel}.",
}
final_content = result.params.get("content", content)
final_channel = result.params.get("channel_id", channel_id)
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{DISCORD_API}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
},
json={"content": final_content},
timeout=15.0,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to send messages in this channel.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": f"Message sent to channel {final_channel}.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool
async def create_dropbox_file(
name: str,
@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
connectors = result.scalars().all()
if not connectors:
return {
"status": "error",
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Dropbox accounts need re-authentication.",
"connector_type": "dropbox",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = DropboxClient(session=db_session, connector_id=cid)
items, err = await client.list_folder("")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{
"folder_path": item.get("path_lower", ""),
"name": item["name"],
}
for item in items
if item.get(".tag") == "folder" and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
"supported_types": _SUPPORTED_TYPES,
}
result = request_approval(
action_type="dropbox_file_creation",
tool_name="create_dropbox_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_path": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_path = result.params.get("parent_folder_path")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_extension(final_name, final_file_type)
if final_connector_id is not None:
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
connector = result.scalars().first()
else:
connector = connectors[0]
connectors = result.scalars().all()
if not connector:
return {
"status": "error",
"message": "Selected Dropbox connector is invalid.",
if not connectors:
return {
"status": "error",
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Dropbox accounts need re-authentication.",
"connector_type": "dropbox",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = DropboxClient(session=db_session, connector_id=cid)
items, err = await client.list_folder("")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{
"folder_path": item.get("path_lower", ""),
"name": item["name"],
}
for item in items
if item.get(".tag") == "folder" and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s",
cid,
exc_info=True,
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
"supported_types": _SUPPORTED_TYPES,
}
client = DropboxClient(session=db_session, connector_id=connector.id)
parent_path = final_parent_folder_path or ""
file_path = (
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
)
if final_file_type == "paper":
created = await client.create_paper_doc(file_path, final_content or "")
file_id = created.get("file_id", "")
web_url = created.get("url", "")
else:
docx_bytes = _markdown_to_docx(final_content or "")
created = await client.upload_file(
file_path, docx_bytes, mode="add", autorename=True
result = request_approval(
action_type="dropbox_file_creation",
tool_name="create_dropbox_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_path": None,
},
context=context,
)
file_id = created.get("id", "")
web_url = ""
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
kb_message_suffix = ""
try:
from app.services.dropbox import DropboxKBSyncService
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_path = result.params.get("parent_folder_path")
kb_service = DropboxKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=file_id,
file_name=final_name,
file_path=file_path,
web_url=web_url,
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_extension(final_name, final_file_type)
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
connector = result.scalars().first()
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
connector = connectors[0]
return {
"status": "success",
"file_id": file_id,
"name": final_name,
"web_url": web_url,
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
}
if not connector:
return {
"status": "error",
"message": "Selected Dropbox connector is invalid.",
}
client = DropboxClient(session=db_session, connector_id=connector.id)
parent_path = final_parent_folder_path or ""
file_path = (
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
)
if final_file_type == "paper":
created = await client.create_paper_doc(
file_path, final_content or ""
)
file_id = created.get("file_id", "")
web_url = created.get("url", "")
else:
docx_bytes = _markdown_to_docx(final_content or "")
created = await client.upload_file(
file_path, docx_bytes, mode="add", autorename=True
)
file_id = created.get("id", "")
web_url = ""
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
kb_message_suffix = ""
try:
from app.services.dropbox import DropboxKBSyncService
kb_service = DropboxKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=file_id,
file_name=final_name,
file_path=file_path,
web_url=web_url,
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": file_id,
"name": final_name,
"web_url": web_url,
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the delete_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_dropbox_file(
file_name: str,
@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Dropbox tool not properly configured.",
}
try:
doc_result = await db_session.execute(
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
document = doc_result.scalars().first()
if not document:
async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
func.lower(
cast(
Document.document_metadata["dropbox_file_name"],
String,
)
)
== func.lower(file_name),
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
)
document = doc_result.scalars().first()
if not document:
return {
"status": "not_found",
"message": (
f"File '{file_name}' not found in your indexed Dropbox files. "
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the file name is different."
),
}
if not document.connector_id:
return {
"status": "error",
"message": "Document has no associated connector.",
}
meta = document.document_metadata or {}
file_path = meta.get("dropbox_path")
file_id = meta.get("dropbox_file_id")
document_id = document.id
if not file_path:
return {
"status": "error",
"message": "File path is missing. Please re-index the file.",
}
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
if not document:
doc_result = await db_session.execute(
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
func.lower(
cast(
Document.document_metadata["dropbox_file_name"],
String,
)
)
== func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
)
)
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Dropbox connector not found or access denied.",
}
document = doc_result.scalars().first()
cfg = connector.config or {}
if cfg.get("auth_expired"):
return {
"status": "auth_error",
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "dropbox",
}
if not document:
return {
"status": "not_found",
"message": (
f"File '{file_name}' not found in your indexed Dropbox files. "
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the file name is different."
),
}
context = {
"file": {
"file_id": file_id,
"file_path": file_path,
"name": file_name,
"document_id": document_id,
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
if not document.connector_id:
return {
"status": "error",
"message": "Document has no associated connector.",
}
result = request_approval(
action_type="dropbox_file_trash",
tool_name="delete_dropbox_file",
params={
"file_path": file_path,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
meta = document.document_metadata or {}
file_path = meta.get("dropbox_path")
file_id = meta.get("dropbox_file_id")
document_id = document.id
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
if not file_path:
return {
"status": "error",
"message": "File path is missing. Please re-index the file.",
}
final_file_path = result.params.get("file_path", file_path)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if final_connector_id != connector.id:
result = await db_session.execute(
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Dropbox connector is invalid or has been disconnected.",
"message": "Dropbox connector not found or access denied.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
)
cfg = connector.config or {}
if cfg.get("auth_expired"):
return {
"status": "auth_error",
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "dropbox",
}
client = DropboxClient(session=db_session, connector_id=actual_connector_id)
await client.delete_file(final_file_path)
context = {
"file": {
"file_id": file_id,
"file_path": file_path,
"name": file_name,
"document_id": document_id,
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
logger.info(f"Dropbox file deleted: path={final_file_path}")
trash_result: dict[str, Any] = {
"status": "success",
"file_id": file_id,
"message": f"Successfully deleted '{file_name}' from Dropbox.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File deleted, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
result = request_approval(
action_type="dropbox_file_trash",
tool_name="delete_dropbox_file",
params={
"file_path": file_path,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
return trash_result
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_file_path = result.params.get("file_path", file_path)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
return {
"status": "error",
"message": "Selected Dropbox connector is invalid or has been disconnected.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
)
client = DropboxClient(
session=db_session, connector_id=actual_connector_id
)
await client.delete_file(final_file_path)
logger.info(f"Dropbox file deleted: path={final_file_path}")
trash_result: dict[str, Any] = {
"status": "success",
"file_id": file_id,
"message": f"Successfully deleted '{file_name}' from Dropbox.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File deleted, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
}
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{prefix}/{model_name}"
@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
cfg.get("provider", ""),
cfg["model_name"],
cfg.get("custom_provider"),
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
db_cfg.provider.value,
db_cfg.model_name,
db_cfg.custom_provider,
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:

View file

@ -0,0 +1,41 @@
from typing import Any
from app.db import SearchSourceConnector
from app.services.composio_service import ComposioService
def split_recipients(value: str | None) -> list[str]:
if not value:
return []
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
def unwrap_composio_data(data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def execute_composio_gmail_tool(
connector: SearchSourceConnector,
user_id: str,
tool_name: str,
params: dict[str, Any],
) -> tuple[Any, str | None]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return None, "Composio connected account ID not found for this Gmail connector."
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Gmail error")
return unwrap_composio_data(result.get("data")), None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool
async def create_gmail_draft(
to: str,
@ -57,246 +75,276 @@ def create_create_gmail_draft_tool(
"""
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_draft_creation",
tool_name="create_gmail_draft",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
)
connector = result.scalars().first()
if not connector:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
actual_connector_id = connector.id
logger.info(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
logger.info(
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_draft_creation",
tool_name="create_gmail_draft",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
}
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw}})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
logger.info(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
created, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_CREATE_EMAIL_DRAFT",
{
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(created, dict):
created = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw}})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
try:
from sqlalchemy.orm.attributes import flag_modified
logger.info(f"Gmail draft created: id={created.get('id')}")
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
logger.info(f"Gmail draft created: id={created.get('id')}")
kb_service = GmailKBSyncService(db_session)
draft_message = created.get("message", {})
kb_result = await kb_service.sync_after_create(
message_id=draft_message.get("id", ""),
thread_id=draft_message.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
draft_id=created.get("id"),
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
draft_message = created.get("message", {})
kb_result = await kb_service.sync_after_create(
message_id=draft_message.get("id", ""),
thread_id=draft_message.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
draft_id=created.get("id"),
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"draft_id": created.get("id"),
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
return {
"status": "success",
"draft_id": created.get("id"),
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the read_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool
async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID.
@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
Returns:
Dictionary with status and the full email content formatted as markdown.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
detail, error = await gmail.get_message_details(message_id)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
connector = result.scalars().first()
if not connector:
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
return {"status": "error", "message": error}
if not detail:
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.agents.new_chat.tools.gmail.search_emails import (
_format_gmail_summary,
)
from app.services.composio_service import ComposioService
service = ComposioService()
detail, error = await service.get_gmail_message_detail(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if error:
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
summary = _format_gmail_summary(detail)
content = (
f"# {summary['subject']}\n\n"
f"**From:** {summary['from']}\n"
f"**To:** {summary['to']}\n"
f"**Date:** {summary['date']}\n\n"
f"## Message Content\n\n"
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {summary['message_id']}\n"
f"- **Thread ID:** {summary['thread_id']}\n"
)
return {
"status": "success",
"message_id": summary["message_id"] or message_id,
"content": content,
}
from app.agents.new_chat.tools.gmail.search_emails import (
_build_credentials,
)
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
detail, error = await gmail.get_message_details(message_id)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
}
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
content = gmail.format_message_to_markdown(detail)
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
"status": "success",
"message_id": message_id,
"content": content,
}
content = gmail.format_message_to_markdown(detail)
return {"status": "success", "message_id": message_id, "content": content}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected account ID not found.")
return build_composio_credentials(cca_id)
raise ValueError("Composio connectors must use Composio tool execution.")
from google.oauth2.credentials import Credentials
@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
)
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
headers = message.get("payload", {}).get("headers", [])
return {
header.get("name", "").lower(): header.get("value", "")
for header in headers
if isinstance(header, dict)
}
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
headers = _gmail_headers(message)
return {
"message_id": message.get("id") or message.get("messageId"),
"thread_id": message.get("threadId"),
"subject": message.get("subject") or headers.get("subject", "No Subject"),
"from": message.get("sender") or headers.get("from", "Unknown"),
"to": message.get("to") or headers.get("to", ""),
"date": message.get("messageTimestamp") or headers.get("date", ""),
"snippet": message.get("snippet") or message.get("messageText", "")[:300],
"labels": message.get("labelIds", []),
}
async def _search_composio_gmail(
connector: SearchSourceConnector,
user_id: str,
query: str,
max_results: int,
) -> dict[str, Any]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.services.composio_service import ComposioService
service = ComposioService()
messages, _next_token, _estimate, error = await service.get_gmail_messages(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=max_results,
)
if error:
return {"status": "error", "message": error}
emails = [_format_gmail_summary(message) for message in messages]
return {
"status": "success",
"emails": emails,
"total": len(emails),
"message": "No emails found." if not emails else None,
}
def create_search_gmail_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the search_gmail tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_gmail tool
"""
del db_session # per-call session — see docstring
@tool
async def search_gmail(
query: str,
@ -90,83 +159,92 @@ def create_search_gmail_tool(
Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20)
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
messages_list, error = await gmail.get_messages_list(
max_results=max_results, query=query
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
connector = result.scalars().first()
if not connector:
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
return {"status": "error", "message": error}
if not messages_list:
return {
"status": "success",
"emails": [],
"total": 0,
"message": "No emails found.",
}
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
return await _search_composio_gmail(
connector, str(user_id), query, max_results
)
emails = []
for msg in messages_list:
detail, err = await gmail.get_message_details(msg["id"])
if err:
continue
headers = {
h["name"].lower(): h["value"]
for h in detail.get("payload", {}).get("headers", [])
}
emails.append(
{
"message_id": detail.get("id"),
"thread_id": detail.get("threadId"),
"subject": headers.get("subject", "No Subject"),
"from": headers.get("from", "Unknown"),
"to": headers.get("to", ""),
"date": headers.get("date", ""),
"snippet": detail.get("snippet", ""),
"labels": detail.get("labelIds", []),
}
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
return {"status": "success", "emails": emails, "total": len(emails)}
messages_list, error = await gmail.get_messages_list(
max_results=max_results, query=query
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
}
return {"status": "error", "message": error}
if not messages_list:
return {
"status": "success",
"emails": [],
"total": 0,
"message": "No emails found.",
}
emails = []
for msg in messages_list:
detail, err = await gmail.get_message_details(msg["id"])
if err:
continue
headers = {
h["name"].lower(): h["value"]
for h in detail.get("payload", {}).get("headers", [])
}
emails.append(
{
"message_id": detail.get("id"),
"thread_id": detail.get("threadId"),
"subject": headers.get("subject", "No Subject"),
"from": headers.get("from", "Unknown"),
"to": headers.get("to", ""),
"date": headers.get("date", ""),
"snippet": detail.get("snippet", ""),
"labels": detail.get("labelIds", []),
}
)
return {"status": "success", "emails": emails, "total": len(emails)}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the send_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool
async def send_gmail_email(
to: str,
@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
"""
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
)
connector = result.scalars().first()
if not connector:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
actual_connector_id = connector.id
logger.info(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
logger.info(
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
}
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
sent = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
logger.info(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
sent, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_SEND_EMAIL",
{
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(sent, dict):
sent = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
sent = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
try:
from sqlalchemy.orm.attributes import flag_modified
logger.info(
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
)
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
message_id=sent.get("id", ""),
thread_id=sent.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
logger.info(
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after send failed: {kb_err}")
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"message_id": sent.get("id"),
"thread_id": sent.get("threadId"),
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
}
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
message_id=sent.get("id", ""),
thread_id=sent.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after send failed: {kb_err}")
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"message_id": sent.get("id"),
"thread_id": sent.get("threadId"),
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the trash_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured trash_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool
async def trash_gmail_email(
email_subject_or_id: str,
@ -55,244 +73,261 @@ def create_trash_gmail_email_tool(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, email_subject_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Email not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, email_subject_or_id
)
return {
"status": "auth_error",
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = context["account"]["id"]
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Email not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
if not message_id:
return {
"status": "error",
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="gmail_email_trash",
tool_name="trash_gmail_email",
params={
"message_id": message_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = context["account"]["id"]
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
final_message_id = result.params.get("message_id", message_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this email.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
if not message_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
logger.info(
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="gmail_email_trash",
tool_name="trash_gmail_email",
params={
"message_id": message_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.trash(userId="me", id=final_message_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
if result.rejected:
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
"status": "rejected",
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
raise
logger.info(f"Gmail email trashed: message_id={final_message_id}")
trash_result: dict[str, Any] = {
"status": "success",
"message_id": final_message_id,
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"Email trashed, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
final_message_id = result.params.get("message_id", message_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
return trash_result
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this email.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
_trashed, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_MOVE_TO_TRASH",
{"user_id": "me", "message_id": final_message_id},
)
if error:
raise RuntimeError(error)
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.trash(userId="me", id=final_message_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Gmail email trashed: message_id={final_message_id}")
trash_result: dict[str, Any] = {
"status": "success",
"message_id": final_message_id,
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"Email trashed, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the update_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool
async def update_gmail_draft(
draft_subject_or_id: str,
@ -76,294 +94,329 @@ def create_update_gmail_draft_tool(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, draft_subject_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Draft not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = account["id"]
draft_id_from_context = context.get("draft_id")
original_subject = email.get("subject", draft_subject_or_id)
final_subject_default = subject if subject else original_subject
final_to_default = to if to else ""
logger.info(
f"Requesting approval for updating Gmail draft: '{original_subject}' "
f"(message_id={message_id}, draft_id={draft_id_from_context})"
)
result = request_approval(
action_type="gmail_draft_update",
tool_name="update_gmail_draft",
params={
"message_id": message_id,
"draft_id": draft_id_from_context,
"to": final_to_default,
"subject": final_subject_default,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", final_to_default)
final_subject = result.params.get("subject", final_subject_default)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_draft_id = result.params.get("draft_id", draft_id_from_context)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this draft.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
async with async_session_maker() as db_session:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, draft_subject_or_id
)
from googleapiclient.discovery import build
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Draft not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
gmail_service = build("gmail", "v1", credentials=creds)
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
final_draft_id = await _find_draft_id_by_message(
gmail_service, message_id
)
if not final_draft_id:
return {
"status": "error",
"message": (
"Could not find this draft in Gmail. "
"It may have already been sent or deleted."
),
}
message = MIMEText(final_body)
if final_to:
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.update(
userId="me",
id=final_draft_id,
body={"message": {"raw": raw}},
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
"Gmail account %s has expired authentication",
account.get("id"),
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
"status": "auth_error",
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = account["id"]
draft_id_from_context = context.get("draft_id")
original_subject = email.get("subject", draft_subject_or_id)
final_subject_default = subject if subject else original_subject
final_to_default = to if to else ""
logger.info(
f"Requesting approval for updating Gmail draft: '{original_subject}' "
f"(message_id={message_id}, draft_id={draft_id_from_context})"
)
result = request_approval(
action_type="gmail_draft_update",
tool_name="update_gmail_draft",
params={
"message_id": message_id,
"draft_id": draft_id_from_context,
"to": final_to_default,
"subject": final_subject_default,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", final_to_default)
final_subject = result.params.get("subject", final_subject_default)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_draft_id = result.params.get("draft_id", draft_id_from_context)
if not final_connector_id:
return {
"status": "error",
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
"message": "No connector found for this draft.",
}
raise
logger.info(f"Gmail draft updated: id={updated.get('id')}")
from sqlalchemy.future import select
kb_message_suffix = ""
if document_id:
try:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import Document
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
doc_result = await db_session.execute(
sa_select(Document).filter(Document.id == document_id)
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
document = doc_result.scalars().first()
if document:
document.source_markdown = final_body
document.title = final_subject
meta = dict(document.document_metadata or {})
meta["subject"] = final_subject
meta["draft_id"] = updated.get("id", final_draft_id)
updated_msg = updated.get("message", {})
if updated_msg.get("id"):
meta["message_id"] = updated_msg["id"]
document.document_metadata = meta
flag_modified(document, "document_metadata")
await db_session.commit()
kb_message_suffix = (
" Your knowledge base has also been updated."
)
logger.info(
f"KB document {document_id} updated for draft {final_draft_id}"
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
)
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
if is_composio_gmail:
final_draft_id = await _find_composio_draft_id_by_message(
connector, user_id, message_id
)
else:
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB update after draft edit failed: {kb_err}")
await db_session.rollback()
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
from googleapiclient.discovery import build
return {
"status": "success",
"draft_id": updated.get("id"),
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
gmail_service = build("gmail", "v1", credentials=creds)
final_draft_id = await _find_draft_id_by_message(
gmail_service, message_id
)
if not final_draft_id:
return {
"status": "error",
"message": (
"Could not find this draft in Gmail. "
"It may have already been sent or deleted."
),
}
message = MIMEText(final_body)
if final_to:
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
updated, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_UPDATE_DRAFT",
{
"user_id": "me",
"draft_id": final_draft_id,
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(updated, dict):
updated = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.update(
userId="me",
id=final_draft_id,
body={"message": {"raw": raw}},
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
return {
"status": "error",
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
}
raise
logger.info(f"Gmail draft updated: id={updated.get('id')}")
kb_message_suffix = ""
if document_id:
try:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
from app.db import Document
doc_result = await db_session.execute(
sa_select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
document.source_markdown = final_body
document.title = final_subject
meta = dict(document.document_metadata or {})
meta["subject"] = final_subject
meta["draft_id"] = updated.get("id", final_draft_id)
updated_msg = updated.get("message", {})
if updated_msg.get("id"):
meta["message_id"] = updated_msg["id"]
document.document_metadata = meta
flag_modified(document, "document_metadata")
await db_session.commit()
kb_message_suffix = (
" Your knowledge base has also been updated."
)
logger.info(
f"KB document {document_id} updated for draft {final_draft_id}"
)
else:
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB update after draft edit failed: {kb_err}")
await db_session.rollback()
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
return {
"status": "success",
"draft_id": updated.get("id"),
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}")
return None
async def _find_composio_draft_id_by_message(
connector: Any, user_id: str, message_id: str
) -> str | None:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
page_token = ""
while True:
params: dict[str, Any] = {
"user_id": "me",
"max_results": 100,
"verbose": False,
}
if page_token:
params["page_token"] = page_token
data, error = await execute_composio_gmail_tool(
connector, user_id, "GMAIL_LIST_DRAFTS", params
)
if error or not isinstance(data, dict):
return None
for draft in data.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft.get("id")
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
if not page_token:
return None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool
async def create_calendar_event(
summary: str,
@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
"All Google Calendar accounts have expired authentication"
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for creating calendar event: summary='{summary}'"
)
result = request_approval(
action_type="google_calendar_event_creation",
tool_name="create_calendar_event",
params={
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": description,
"location": location,
"attendees": attendees,
"timezone": context.get("timezone"),
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
final_summary = result.params.get("summary", summary)
final_start_datetime = result.params.get("start_datetime", start_datetime)
final_end_datetime = result.params.get("end_datetime", end_datetime)
final_description = result.params.get("description", description)
final_location = result.params.get("location", location)
final_attendees = result.params.get("attendees", attendees)
final_connector_id = result.params.get("connector_id")
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Event summary cannot be empty."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
"All Google Calendar accounts have expired authentication"
)
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for creating calendar event: summary='{summary}'"
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
result = request_approval(
action_type="google_calendar_event_creation",
tool_name="create_calendar_event",
params={
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": description,
"location": location,
"attendees": attendees,
"timezone": context.get("timezone"),
"connector_id": None,
},
context=context,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
"start": {"dateTime": final_start_datetime, "timeZone": tz},
"end": {"dateTime": final_end_datetime, "timeZone": tz},
}
if final_description:
event_body["description"] = final_description
if final_location:
event_body["location"] = final_location
if final_attendees:
event_body["attendees"] = [
{"email": e.strip()} for e in final_attendees if e.strip()
final_summary = result.params.get("summary", summary)
final_start_datetime = result.params.get(
"start_datetime", start_datetime
)
final_end_datetime = result.params.get("end_datetime", end_datetime)
final_description = result.params.get("description", description)
final_location = result.params.get("location", location)
final_attendees = result.params.get("attendees", attendees)
final_connector_id = result.params.get("connector_id")
if not final_summary or not final_summary.strip():
return {
"status": "error",
"message": "Event summary cannot be empty.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
try:
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.insert(calendarId="primary", body=event_body)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
actual_connector_id = connector.id
return {
"status": "success",
"event_id": created.get("id"),
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
}
logger.info(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
"start": {"dateTime": final_start_datetime, "timeZone": tz},
"end": {"dateTime": final_end_datetime, "timeZone": tz},
}
if final_description:
event_body["description"] = final_description
if final_location:
event_body["location"] = final_location
if final_attendees:
event_body["attendees"] = [
{"email": e.strip()} for e in final_attendees if e.strip()
]
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params = {
"calendar_id": "primary",
"summary": final_summary,
"start_datetime": final_start_datetime,
"end_datetime": final_end_datetime,
"timezone": tz,
"attendees": final_attendees or [],
}
if final_description:
composio_params["description"] = final_description
if final_location:
composio_params["location"] = final_location
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_CREATE_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
created = composio_result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.insert(calendarId="primary", body=event_body)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"event_id": created.get("id"),
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the delete_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_calendar_event(
event_title_or_id: str,
@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, event_title_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch deletion context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Calendar account %s has expired authentication",
account.get("id"),
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, event_title_or_id
)
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch deletion context: {error_msg}")
return {"status": "error", "message": error_msg}
if not event_id:
return {
"status": "error",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Calendar account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="google_calendar_event_deletion",
tool_name="delete_calendar_event",
params={
"event_id": event_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
if not event_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
logger.info(
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="google_calendar_event_deletion",
tool_name="delete_calendar_event",
params={
"event_id": event_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.delete(calendarId="primary", eventId=final_event_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
if result.rejected:
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
"status": "rejected",
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
raise
logger.info(f"Calendar event deleted: event_id={final_event_id}")
delete_result: dict[str, Any] = {
"status": "success",
"event_id": final_event_id,
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
delete_result["warning"] = (
f"Event deleted, but failed to remove from knowledge base: {e!s}"
)
delete_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
delete_result["message"] = (
f"{delete_result.get('message', '')} (also removed from knowledge base)"
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
return delete_result
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_DELETE_EVENT",
params={
"calendar_id": "primary",
"event_id": final_event_id,
},
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.delete(calendarId="primary", eventId=final_event_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event deleted: event_id={final_event_id}")
delete_result: dict[str, Any] = {
"status": "success",
"event_id": final_event_id,
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
delete_result["warning"] = (
f"Event deleted, but failed to remove from knowledge base: {e!s}"
)
delete_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
delete_result["message"] = (
f"{delete_result.get('message', '')} (also removed from knowledge base)"
)
return delete_result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
]
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
if "T" in value:
return value
time = "23:59:59" if is_end else "00:00:00"
return f"{value}T{time}Z"
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return events
def create_search_calendar_events_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the search_calendar_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_calendar_events tool
"""
del db_session # per-call session — see docstring
@tool
async def search_calendar_events(
start_date: str,
@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Calendar tool not properly configured.",
@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
max_results = min(max_results, 50)
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
)
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import GoogleCalendarConnector
cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
return {
"status": "auth_error",
"message": error,
"connector_type": "google_calendar",
}
if "no events found" in error.lower():
return {
"status": "success",
"events": [],
"total": 0,
"message": error,
}
return {"status": "error", "message": error}
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
from app.services.composio_service import ComposioService
return {"status": "success", "events": events, "total": len(events)}
events_raw, error = await ComposioService().get_calendar_events(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
time_min=_to_calendar_boundary(start_date, is_end=False),
time_max=_to_calendar_boundary(end_date, is_end=True),
max_results=max_results,
)
if not events_raw and not error:
error = "No events found in the specified date range."
else:
creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import (
GoogleCalendarConnector,
)
cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "google_calendar",
}
if "no events found" in error.lower():
return {
"status": "success",
"events": [],
"total": 0,
"message": error,
}
return {"status": "error", "message": error}
events = _format_calendar_events(events_raw)
return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the update_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool
async def update_calendar_event(
event_title_or_id: str,
@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
"""
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, event_title_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
if context.get("auth_expired"):
logger.warning("Google Calendar account has expired authentication")
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if not event_id:
return {
"status": "error",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
logger.info(
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
)
result = request_approval(
action_type="google_calendar_event_update",
tool_name="update_calendar_event",
params={
"event_id": event_id,
"document_id": document_id,
"connector_id": connector_id_from_context,
"new_summary": new_summary,
"new_start_datetime": new_start_datetime,
"new_end_datetime": new_end_datetime,
"new_description": new_description,
"new_location": new_location,
"new_attendees": new_attendees,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
}
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_new_summary = result.params.get("new_summary", new_summary)
final_new_start_datetime = result.params.get(
"new_start_datetime", new_start_datetime
)
final_new_end_datetime = result.params.get(
"new_end_datetime", new_end_datetime
)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_location = result.params.get("new_location", new_location)
final_new_attendees = result.params.get("new_attendees", new_attendees)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, event_title_or_id
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
logger.info(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
if context.get("auth_expired"):
logger.warning("Google Calendar account has expired authentication")
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
if not event_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
logger.info(
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
)
result = request_approval(
action_type="google_calendar_event_update",
tool_name="update_calendar_event",
params={
"event_id": event_id,
"document_id": document_id,
"connector_id": connector_id_from_context,
"new_summary": new_summary,
"new_start_datetime": new_start_datetime,
"new_end_datetime": new_end_datetime,
"new_description": new_description,
"new_location": new_location,
"new_attendees": new_attendees,
},
context=context,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
}
update_body: dict[str, Any] = {}
if final_new_summary is not None:
update_body["summary"] = final_new_summary
if final_new_start_datetime is not None:
update_body["start"] = _build_time_body(
final_new_start_datetime, context
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
if final_new_end_datetime is not None:
update_body["end"] = _build_time_body(final_new_end_datetime, context)
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:
update_body["location"] = final_new_location
if final_new_attendees is not None:
update_body["attendees"] = [
{"email": e.strip()} for e in final_new_attendees if e.strip()
final_new_summary = result.params.get("new_summary", new_summary)
final_new_start_datetime = result.params.get(
"new_start_datetime", new_start_datetime
)
final_new_end_datetime = result.params.get(
"new_end_datetime", new_end_datetime
)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_location = result.params.get("new_location", new_location)
final_new_attendees = result.params.get("new_attendees", new_attendees)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
if not update_body:
return {
"status": "error",
"message": "No changes specified. Please provide at least one field to update.",
}
try:
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.patch(
calendarId="primary",
eventId=final_event_id,
body=update_body,
)
.execute()
),
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
connector = result.scalars().first()
if not connector:
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
raise
logger.info(f"Calendar event updated: event_id={final_event_id}")
actual_connector_id = connector.id
kb_message_suffix = ""
if document_id is not None:
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
logger.info(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
event_id=final_event_id,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
return {
"status": "success",
"event_id": final_event_id,
"html_link": updated.get("htmlLink"),
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
}
update_body: dict[str, Any] = {}
if final_new_summary is not None:
update_body["summary"] = final_new_summary
if final_new_start_datetime is not None:
update_body["start"] = _build_time_body(
final_new_start_datetime, context
)
if final_new_end_datetime is not None:
update_body["end"] = _build_time_body(
final_new_end_datetime, context
)
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:
update_body["location"] = final_new_location
if final_new_attendees is not None:
update_body["attendees"] = [
{"email": e.strip()} for e in final_new_attendees if e.strip()
]
if not update_body:
return {
"status": "error",
"message": "No changes specified. Please provide at least one field to update.",
}
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params: dict[str, Any] = {
"calendar_id": "primary",
"event_id": final_event_id,
}
if final_new_summary is not None:
composio_params["summary"] = final_new_summary
if final_new_start_datetime is not None:
composio_params["start_time"] = final_new_start_datetime
if final_new_end_datetime is not None:
composio_params["end_time"] = final_new_end_datetime
if final_new_description is not None:
composio_params["description"] = final_new_description
if final_new_location is not None:
composio_params["location"] = final_new_location
if final_new_attendees is not None:
composio_params["attendees"] = [
e.strip() for e in final_new_attendees if e.strip()
]
if not _is_date_only(
final_new_start_datetime or final_new_end_datetime or ""
):
composio_params["timezone"] = context.get("timezone", "UTC")
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_PATCH_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
updated = composio_result.get("data", {})
if isinstance(updated, dict):
updated = updated.get("data", updated)
if isinstance(updated, dict):
updated = updated.get("response_data", updated)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.patch(
calendarId="primary",
eventId=final_event_id,
body=update_body,
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event updated: event_id={final_event_id}")
kb_message_suffix = ""
if document_id is not None:
try:
from app.services.google_calendar import (
GoogleCalendarKBSyncService,
)
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
event_id=final_event_id,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
return {
"status": "success",
"event_id": final_event_id,
"html_link": updated.get("htmlLink"),
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured create_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool
async def create_google_drive_file(
name: str,
@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
f"create_google_drive_file called: name='{name}', type='{file_type}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
@ -78,195 +98,232 @@ def create_create_google_drive_file_tool(
}
try:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
async with async_session_maker() as db_session:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Google Drive accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
logger.info(
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
)
result = request_approval(
action_type="google_drive_file_creation",
tool_name="create_google_drive_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
mime_type = _MIME_MAP.get(final_file_type)
if not mime_type:
return {
"status": "error",
"message": f"Unsupported file type '{final_file_type}'.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
}
actual_connector_id = connector.id
return {"status": "error", "message": context["error"]}
logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
credentials=pre_built_creds,
)
try:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
"All Google Drive accounts have expired authentication"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
"status": "auth_error",
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
raise
logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.google_drive import GoogleDriveKBSyncService
kb_service = GoogleDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=mime_type,
web_view_link=created.get("webViewLink"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
logger.info(
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
)
result = request_approval(
action_type="google_drive_file_creation",
tool_name="create_google_drive_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_view_link": created.get("webViewLink"),
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
mime_type = _MIME_MAP.get(final_file_type)
if not mime_type:
return {
"status": "error",
"message": f"Unsupported file type '{final_file_type}'.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
)
try:
if is_composio_drive:
from app.services.composio_service import ComposioService
params: dict[str, Any] = {
"name": final_name,
"mimeType": mime_type,
"fields": "id,name,webViewLink,mimeType",
}
if final_parent_folder_id:
params["parents"] = [final_parent_folder_id]
if final_content:
params["description"] = final_content[:4096]
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_CREATE_FILE",
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
created = result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
if not isinstance(created, dict):
created = {}
else:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.google_drive import GoogleDriveKBSyncService
kb_service = GoogleDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=mime_type,
web_view_link=created.get("webViewLink"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_view_link": created.get("webViewLink"),
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the delete_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured delete_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_google_drive_file(
file_name: str,
@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, file_name
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"File not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Drive account %s has expired authentication",
account.get("id"),
async with async_session_maker() as db_session:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, file_name
)
return {
"status": "auth_error",
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
file = context["file"]
file_id = file["file_id"]
document_id = file.get("document_id")
connector_id_from_context = context["account"]["id"]
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"File not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
if not file_id:
return {
"status": "error",
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
}
logger.info(
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="google_drive_file_trash",
tool_name="delete_google_drive_file",
params={
"file_id": file_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
credentials=pre_built_creds,
)
try:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
"Google Drive account %s has expired authentication",
account.get("id"),
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
"status": "auth_error",
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
raise
logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
)
file = context["file"]
file_id = file["file_id"]
document_id = file.get("document_id")
connector_id_from_context = context["account"]["id"]
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file['name']}' to trash.",
}
if not file_id:
return {
"status": "error",
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
logger.info(
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="google_drive_file_trash",
tool_name="delete_google_drive_file",
params={
"file_id": file_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
return trash_result
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
)
try:
if is_composio_drive:
from app.services.composio_service import ComposioService
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_TRASH_FILE",
params={"file_id": final_file_id},
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
else:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file['name']}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
"update_gmail_draft",
"create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""Factory function to create the create_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Jira API
blocking.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured create_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def create_jira_issue(
project_key: str,
@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Jira accounts need re-authentication.",
"connector_type": "jira",
}
result = request_approval(
action_type="jira_issue_creation",
tool_name="create_jira_issue",
params={
"project_key": project_key,
"summary": summary,
"issue_type": issue_type,
"description": description,
"priority": priority,
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_project_key = result.params.get("project_key", project_key)
final_summary = result.params.get("summary", summary)
final_issue_type = result.params.get("issue_type", issue_type)
final_description = result.params.get("description", description)
final_priority = result.params.get("priority", priority)
final_connector_id = result.params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Issue summary cannot be empty."}
if not final_project_key:
return {"status": "error", "message": "A project must be selected."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
connector = result.scalars().first()
if not connector:
return {"status": "error", "message": "No Jira connector found."}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Jira accounts need re-authentication.",
"connector_type": "jira",
}
result = request_approval(
action_type="jira_issue_creation",
tool_name="create_jira_issue",
params={
"project_key": project_key,
"summary": summary,
"issue_type": issue_type,
"description": description,
"priority": priority,
"connector_id": connector_id,
},
context=context,
)
connector = result.scalars().first()
if not connector:
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_project_key = result.params.get("project_key", project_key)
final_summary = result.params.get("summary", summary)
final_issue_type = result.params.get("issue_type", issue_type)
final_description = result.params.get("description", description)
final_priority = result.params.get("priority", priority)
final_connector_id = result.params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip():
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
"message": "Issue summary cannot be empty.",
}
if not final_project_key:
return {"status": "error", "message": "A project must be selected."}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
jira_client = await jira_history._get_jira_client()
api_result = await asyncio.to_thread(
jira_client.create_issue,
project_key=final_project_key,
summary=final_summary,
issue_type=final_issue_type,
description=final_description,
priority=final_priority,
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
from sqlalchemy.future import select
issue_key = api_result.get("key", "")
issue_url = (
f"{jira_history._base_url}/browse/{issue_key}"
if jira_history._base_url and issue_key
else ""
)
from app.db import SearchSourceConnector, SearchSourceConnectorType
kb_message_suffix = ""
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=issue_key,
issue_identifier=issue_key,
issue_title=final_summary,
description=final_description,
state="To Do",
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Jira connector found.",
}
actual_connector_id = connector.id
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
return {
"status": "success",
"issue_key": issue_key,
"issue_url": issue_url,
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
jira_client = await jira_history._get_jira_client()
api_result = await asyncio.to_thread(
jira_client.create_issue,
project_key=final_project_key,
summary=final_summary,
issue_type=final_issue_type,
description=final_description,
priority=final_priority,
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_key = api_result.get("key", "")
issue_url = (
f"{jira_history._base_url}/browse/{issue_key}"
if jira_history._base_url and issue_key
else ""
)
kb_message_suffix = ""
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=issue_key,
issue_identifier=issue_key,
issue_title=final_summary,
description=final_description,
state="To Do",
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_key": issue_key,
"issue_url": issue_url,
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""Factory function to create the delete_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured delete_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_jira_issue(
issue_title_or_key: str,
@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, issue_title_or_key
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_deletion",
tool_name="delete_jira_issue",
params={
"issue_key": issue_key,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_key = result.params.get("issue_key", issue_key)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, issue_title_or_key
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_deletion",
tool_name="delete_jira_issue",
params={
"issue_key": issue_key,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
if result.rejected:
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
raise
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
final_issue_key = result.params.get("issue_key", issue_key)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
message = f"Jira issue {final_issue_key} deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
return {
"status": "success",
"issue_key": final_issue_key,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Jira issue {final_issue_key} deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"issue_key": final_issue_key,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""Factory function to create the update_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured update_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def update_jira_issue(
issue_title_or_key: str,
@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, issue_title_or_key
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_update",
tool_name="update_jira_issue",
params={
"issue_key": issue_key,
"document_id": document_id,
"new_summary": new_summary,
"new_description": new_description,
"new_priority": new_priority,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_key = result.params.get("issue_key", issue_key)
final_summary = result.params.get("new_summary", new_summary)
final_description = result.params.get("new_description", new_description)
final_priority = result.params.get("new_priority", new_priority)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
async with async_session_maker() as db_session:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, issue_title_or_key
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
fields: dict[str, Any] = {}
if final_summary:
fields["summary"] = final_summary
if final_description is not None:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [{"type": "text", "text": final_description}],
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
],
}
if final_priority:
fields["priority"] = {"name": final_priority}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
if not fields:
return {"status": "error", "message": "No changes specified."}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
result = request_approval(
action_type="jira_issue_update",
tool_name="update_jira_issue",
params={
"issue_key": issue_key,
"document_id": document_id,
"new_summary": new_summary,
"new_description": new_description,
"new_priority": new_priority,
"connector_id": connector_id_from_context,
},
context=context,
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(
jira_client.update_issue, final_issue_key, fields
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
if result.rejected:
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
raise
issue_url = (
f"{jira_history._base_url}/browse/{final_issue_key}"
if jira_history._base_url and final_issue_key
else ""
)
final_issue_key = result.params.get("issue_key", issue_key)
final_summary = result.params.get("new_summary", new_summary)
final_description = result.params.get(
"new_description", new_description
)
final_priority = result.params.get("new_priority", new_priority)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
kb_message_suffix = ""
if final_document_id:
try:
from app.services.jira import JiraKBSyncService
from sqlalchemy.future import select
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_key,
user_id=user_id,
search_space_id=search_space_id,
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
fields: dict[str, Any] = {}
if final_summary:
fields["summary"] = final_summary
if final_description is not None:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [
{"type": "text", "text": final_description}
],
}
],
}
if final_priority:
fields["priority"] = {"name": final_priority}
if not fields:
return {"status": "error", "message": "No changes specified."}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(
jira_client.update_issue, final_issue_key, fields
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_url = (
f"{jira_history._base_url}/browse/{final_issue_key}"
if jira_history._base_url and final_issue_key
else ""
)
kb_message_suffix = ""
if final_document_id:
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_key,
user_id=user_id,
search_space_id=search_space_id,
)
else:
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return {
"status": "success",
"issue_key": final_issue_key,
"issue_url": issue_url,
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
}
return {
"status": "success",
"issue_key": final_issue_key,
"issue_url": issue_url,
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the create_linear_issue tool.
"""Factory function to create the create_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Database session for accessing the Linear connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
Returns:
Configured create_linear_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def create_linear_issue(
@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
"""
logger.info(f"create_linear_issue called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
}
try:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
result = request_approval(
action_type="linear_issue_creation",
tool_name="create_linear_issue",
params={
"title": title,
"description": description,
"team_id": None,
"state_id": None,
"assignee_id": None,
"priority": None,
"label_ids": [],
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Linear issue creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_description = result.params.get("description", description)
final_team_id = result.params.get("team_id")
final_state_id = result.params.get("state_id")
final_assignee_id = result.params.get("assignee_id")
final_priority = result.params.get("priority")
final_label_ids = result.params.get("label_ids") or []
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {"status": "error", "message": "Issue title cannot be empty."}
if not final_team_id:
return {
"status": "error",
"message": "A team must be selected to create an issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
connector = result.scalars().first()
if not connector:
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
result = request_approval(
action_type="linear_issue_creation",
tool_name="create_linear_issue",
params={
"title": title,
"description": description,
"team_id": None,
"state_id": None,
"assignee_id": None,
"priority": None,
"label_ids": [],
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Linear issue creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_description = result.params.get("description", description)
final_team_id = result.params.get("team_id")
final_state_id = result.params.get("state_id")
final_assignee_id = result.params.get("assignee_id")
final_priority = result.params.get("priority")
final_label_ids = result.params.get("label_ids") or []
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
"message": "Issue title cannot be empty.",
}
actual_connector_id = connector.id
logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
if not final_team_id:
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
"message": "A team must be selected to create an issue.",
}
logger.info(f"Validated Linear connector: id={actual_connector_id}")
logger.info(
f"Creating Linear issue with final params: title='{final_title}'"
)
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.create_issue(
team_id=final_team_id,
title=final_title,
description=final_description,
state_id=final_state_id,
assignee_id=final_assignee_id,
priority=final_priority,
label_ids=final_label_ids if final_label_ids else None,
)
from sqlalchemy.future import select
if result.get("status") == "error":
logger.error(f"Failed to create Linear issue: {result.get('message')}")
return {"status": "error", "message": result.get("message")}
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger.info(
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
)
kb_message_suffix = ""
try:
from app.services.linear import LinearKBSyncService
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=result.get("id"),
issue_identifier=result.get("identifier", ""),
issue_title=result.get("title", final_title),
issue_url=result.get("url"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={actual_connector_id}")
return {
"status": "success",
"issue_id": result.get("id"),
"identifier": result.get("identifier"),
"url": result.get("url"),
"message": (result.get("message", "") + kb_message_suffix),
}
logger.info(
f"Creating Linear issue with final params: title='{final_title}'"
)
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.create_issue(
team_id=final_team_id,
title=final_title,
description=final_description,
state_id=final_state_id,
assignee_id=final_assignee_id,
priority=final_priority,
label_ids=final_label_ids if final_label_ids else None,
)
if result.get("status") == "error":
logger.error(
f"Failed to create Linear issue: {result.get('message')}"
)
return {"status": "error", "message": result.get("message")}
logger.info(
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
)
kb_message_suffix = ""
try:
from app.services.linear import LinearKBSyncService
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=result.get("id"),
issue_identifier=result.get("identifier", ""),
issue_title=result.get("title", final_title),
issue_url=result.get("url"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_id": result.get("id"),
"identifier": result.get("identifier"),
"url": result.get("url"),
"message": (result.get("message", "") + kb_message_suffix),
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the delete_linear_issue tool.
"""Factory function to create the delete_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Database session for accessing the Linear connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
Returns:
Configured delete_linear_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_linear_issue(
@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
}
try:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_delete_context(
search_space_id, user_id, issue_ref
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
issue_identifier = context["issue"].get("identifier", "")
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="linear_issue_deletion",
tool_name="delete_linear_issue",
params={
"issue_id": issue_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Linear issue deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_delete_context(
search_space_id, user_id, issue_ref
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
issue_identifier = context["issue"].get("identifier", "")
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="linear_issue_deletion",
tool_name="delete_linear_issue",
params={
"issue_id": issue_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Linear issue deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(f"Validated Linear connector: id={actual_connector_id}")
else:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
"message": "No connector found for this issue.",
}
actual_connector_id = connector.id
logger.info(f"Validated Linear connector: id={actual_connector_id}")
else:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.archive_issue(issue_id=final_issue_id)
result = await linear_client.archive_issue(issue_id=final_issue_id)
logger.info(
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
)
logger.info(
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
)
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from app.db import Document
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if issue_identifier:
result["message"] = (
f"Issue {issue_identifier} archived successfully."
)
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} Also removed from the knowledge base."
)
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if issue_identifier:
result["message"] = (
f"Issue {issue_identifier} archived successfully."
)
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} Also removed from the knowledge base."
)
return result
return result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the update_linear_issue tool.
"""Factory function to create the update_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Database session for accessing the Linear connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
Returns:
Configured update_linear_issue tool
"""
del db_session # per-call session — see docstring
@tool
async def update_linear_issue(
@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
"""
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
}
try:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, issue_ref
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for update context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
team = context.get("team", {})
new_state_id = _resolve_state(team, new_state_name)
new_assignee_id = _resolve_assignee(team, new_assignee_email)
new_label_ids = _resolve_labels(team, new_label_names)
logger.info(
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
)
result = request_approval(
action_type="linear_issue_update",
tool_name="update_linear_issue",
params={
"issue_id": issue_id,
"document_id": document_id,
"new_title": new_title,
"new_description": new_description,
"new_state_id": new_state_id,
"new_assignee_id": new_assignee_id,
"new_priority": new_priority,
"new_label_ids": new_label_ids,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Linear issue update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_document_id = result.params.get("document_id", document_id)
final_new_title = result.params.get("new_title", new_title)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_state_id = result.params.get("new_state_id", new_state_id)
final_new_assignee_id = result.params.get(
"new_assignee_id", new_assignee_id
)
final_new_priority = result.params.get("new_priority", new_priority)
final_new_label_ids: list[str] | None = result.params.get(
"new_label_ids", new_label_ids
)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
if not final_connector_id:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
async with async_session_maker() as db_session:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, issue_ref
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={final_connector_id}")
logger.info(
f"Updating Linear issue with final params: issue_id={final_issue_id}"
)
linear_client = LinearConnector(
session=db_session, connector_id=final_connector_id
)
updated_issue = await linear_client.update_issue(
issue_id=final_issue_id,
title=final_new_title,
description=final_new_description,
state_id=final_new_state_id,
assignee_id=final_new_assignee_id,
priority=final_new_priority,
label_ids=final_new_label_ids,
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for update context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
if updated_issue.get("status") == "error":
logger.error(
f"Failed to update Linear issue: {updated_issue.get('message')}"
)
return {
"status": "error",
"message": updated_issue.get("message"),
}
issue_id = context["issue"]["id"]
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
)
team = context.get("team", {})
new_state_id = _resolve_state(team, new_state_name)
new_assignee_id = _resolve_assignee(team, new_assignee_email)
new_label_ids = _resolve_labels(team, new_label_names)
if final_document_id is not None:
logger.info(
f"Updating knowledge base for document {final_document_id}..."
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
)
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_id,
user_id=user_id,
search_space_id=search_space_id,
result = request_approval(
action_type="linear_issue_update",
tool_name="update_linear_issue",
params={
"issue_id": issue_id,
"document_id": document_id,
"new_title": new_title,
"new_description": new_description,
"new_state_id": new_state_id,
"new_assignee_id": new_assignee_id,
"new_priority": new_priority,
"new_label_ids": new_label_ids,
"connector_id": connector_id_from_context,
},
context=context,
)
if kb_result["status"] == "success":
logger.info(
f"Knowledge base successfully updated for issue {final_issue_id}"
)
kb_message = " Your knowledge base has also been updated."
elif kb_result["status"] == "not_indexed":
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
else:
logger.warning(
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
)
kb_message = " Your knowledge base will be updated in the next scheduled sync."
else:
kb_message = ""
identifier = updated_issue.get("identifier")
default_msg = f"Issue {identifier} updated successfully."
return {
"status": "success",
"identifier": identifier,
"url": updated_issue.get("url"),
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
}
if result.rejected:
logger.info("Linear issue update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_document_id = result.params.get("document_id", document_id)
final_new_title = result.params.get("new_title", new_title)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_state_id = result.params.get("new_state_id", new_state_id)
final_new_assignee_id = result.params.get(
"new_assignee_id", new_assignee_id
)
final_new_priority = result.params.get("new_priority", new_priority)
final_new_label_ids: list[str] | None = result.params.get(
"new_label_ids", new_label_ids
)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
if not final_connector_id:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={final_connector_id}")
logger.info(
f"Updating Linear issue with final params: issue_id={final_issue_id}"
)
linear_client = LinearConnector(
session=db_session, connector_id=final_connector_id
)
updated_issue = await linear_client.update_issue(
issue_id=final_issue_id,
title=final_new_title,
description=final_new_description,
state_id=final_new_state_id,
assignee_id=final_new_assignee_id,
priority=final_new_priority,
label_ids=final_new_label_ids,
)
if updated_issue.get("status") == "error":
logger.error(
f"Failed to update Linear issue: {updated_issue.get('message')}"
)
return {
"status": "error",
"message": updated_issue.get("message"),
}
logger.info(
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
)
if final_document_id is not None:
logger.info(
f"Updating knowledge base for document {final_document_id}..."
)
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_id,
user_id=user_id,
search_space_id=search_space_id,
)
if kb_result["status"] == "success":
logger.info(
f"Knowledge base successfully updated for issue {final_issue_id}"
)
kb_message = " Your knowledge base has also been updated."
elif kb_result["status"] == "not_indexed":
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
else:
logger.warning(
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
)
kb_message = " Your knowledge base will be updated in the next scheduled sync."
else:
kb_message = ""
identifier = updated_issue.get("identifier")
default_msg = f"Issue {identifier} updated successfully."
return {
"status": "success",
"identifier": identifier,
"url": updated_issue.get("url"),
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
@ -17,6 +18,23 @@ def create_create_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_luma_event tool
"""
del db_session # per-call session — see docstring
@tool
async def create_luma_event(
name: str,
@ -40,83 +58,86 @@ def create_create_luma_event_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
result = request_approval(
action_type="luma_create_event",
tool_name="create_luma_event",
params={
"name": name,
"start_at": start_at,
"end_at": end_at,
"description": description,
"timezone": timezone,
},
context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Event was not created.",
}
final_name = result.params.get("name", name)
final_start = result.params.get("start_at", start_at)
final_end = result.params.get("end_at", end_at)
final_desc = result.params.get("description", description)
final_tz = result.params.get("timezone", timezone)
api_key = get_api_key(connector)
headers = luma_headers(api_key)
body: dict[str, Any] = {
"name": final_name,
"start_at": final_start,
"end_at": final_end,
"timezone": final_tz,
}
if final_desc:
body["description_md"] = final_desc
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{LUMA_API}/event/create",
headers=headers,
json=body,
result = request_approval(
action_type="luma_create_event",
tool_name="create_luma_event",
params={
"name": name,
"start_at": start_at,
"end_at": end_at,
"description": description,
"timezone": timezone,
},
context={"connector_id": connector.id},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Luma Plus subscription required to create events via API.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}{resp.text[:200]}",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Event was not created.",
}
data = resp.json()
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
final_name = result.params.get("name", name)
final_start = result.params.get("start_at", start_at)
final_end = result.params.get("end_at", end_at)
final_desc = result.params.get("description", description)
final_tz = result.params.get("timezone", timezone)
return {
"status": "success",
"event_id": event_id,
"message": f"Event '{final_name}' created on Luma.",
}
api_key = get_api_key(connector)
headers = luma_headers(api_key)
body: dict[str, Any] = {
"name": final_name,
"start_at": final_start,
"end_at": final_end,
"timezone": final_tz,
}
if final_desc:
body["description_md"] = final_desc
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{LUMA_API}/event/create",
headers=headers,
json=body,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Luma Plus subscription required to create events via API.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}{resp.text[:200]}",
}
data = resp.json()
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
return {
"status": "success",
"event_id": event_id,
"message": f"Event '{final_name}' created on Luma.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_luma_events_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the list_luma_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_luma_events tool
"""
del db_session # per-call session — see docstring
@tool
async def list_luma_events(
max_results: int = 25,
@ -28,77 +47,80 @@ def create_list_luma_events_tool(
Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50)
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
api_key = get_api_key(connector)
headers = luma_headers(api_key)
api_key = get_api_key(connector)
headers = luma_headers(api_key)
all_entries: list[dict] = []
cursor = None
all_entries: list[dict] = []
cursor = None
async with httpx.AsyncClient(timeout=20.0) as client:
while len(all_entries) < max_results:
params: dict[str, Any] = {
"limit": min(100, max_results - len(all_entries))
}
if cursor:
params["cursor"] = cursor
async with httpx.AsyncClient(timeout=20.0) as client:
while len(all_entries) < max_results:
params: dict[str, Any] = {
"limit": min(100, max_results - len(all_entries))
}
if cursor:
params["cursor"] = cursor
resp = await client.get(
f"{LUMA_API}/calendar/list-events",
headers=headers,
params=params,
resp = await client.get(
f"{LUMA_API}/calendar/list-events",
headers=headers,
params=params,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
entries = data.get("entries", [])
if not entries:
break
all_entries.extend(entries)
next_cursor = data.get("next_cursor")
if not next_cursor:
break
cursor = next_cursor
events = []
for entry in all_entries[:max_results]:
ev = entry.get("event", {})
geo = ev.get("geo_info", {})
events.append(
{
"event_id": entry.get("api_id"),
"name": ev.get("name", "Untitled"),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location": geo.get("name", ""),
"url": ev.get("url", ""),
"visibility": ev.get("visibility", ""),
}
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
entries = data.get("entries", [])
if not entries:
break
all_entries.extend(entries)
next_cursor = data.get("next_cursor")
if not next_cursor:
break
cursor = next_cursor
events = []
for entry in all_entries[:max_results]:
ev = entry.get("event", {})
geo = ev.get("geo_info", {})
events.append(
{
"event_id": entry.get("api_id"),
"name": ev.get("name", "Untitled"),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location": geo.get("name", ""),
"url": ev.get("url", ""),
"visibility": ev.get("visibility", ""),
}
)
return {"status": "success", "events": events, "total": len(events)}
return {"status": "success", "events": events, "total": len(events)}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_luma_event_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the read_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_luma_event tool
"""
del db_session # per-call session — see docstring
@tool
async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event.
@ -26,60 +45,63 @@ def create_read_luma_event_tool(
Dictionary with status and full event details including
description, attendees count, meeting URL.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."}
try:
connector = await get_luma_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
api_key = get_api_key(connector)
headers = luma_headers(api_key)
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
f"{LUMA_API}/events/{event_id}",
headers=headers,
async with async_session_maker() as db_session:
connector = await get_luma_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code == 404:
return {
"status": "not_found",
"message": f"Event '{event_id}' not found.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
api_key = get_api_key(connector)
headers = luma_headers(api_key)
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
f"{LUMA_API}/events/{event_id}",
headers=headers,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code == 404:
return {
"status": "not_found",
"message": f"Event '{event_id}' not found.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
ev = data.get("event", data)
geo = ev.get("geo_info", {})
event_detail = {
"event_id": event_id,
"name": ev.get("name", ""),
"description": ev.get("description", ""),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location_name": geo.get("name", ""),
"address": geo.get("address", ""),
"url": ev.get("url", ""),
"meeting_url": ev.get("meeting_url", ""),
"visibility": ev.get("visibility", ""),
"cover_url": ev.get("cover_url", ""),
}
data = resp.json()
ev = data.get("event", data)
geo = ev.get("geo_info", {})
event_detail = {
"event_id": event_id,
"name": ev.get("name", ""),
"description": ev.get("description", ""),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location_name": geo.get("name", ""),
"address": geo.get("address", ""),
"url": ev.get("url", ""),
"meeting_url": ev.get("meeting_url", ""),
"visibility": ev.get("visibility", ""),
"cover_url": ev.get("cover_url", ""),
}
return {"status": "success", "event": event_detail}
return {"status": "success", "event": event_detail}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@ -20,8 +21,17 @@ def create_create_notion_page_tool(
"""
Factory function to create the create_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Notion API
blocking.
Args:
db_session: Database session for accessing Notion connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@ -29,6 +39,7 @@ def create_create_notion_page_tool(
Returns:
Configured create_notion_page tool
"""
del db_session # per-call session — see docstring
@tool
async def create_notion_page(
@ -67,7 +78,7 @@ def create_create_notion_page_tool(
"""
logger.info(f"create_notion_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@ -77,154 +88,157 @@ def create_create_notion_page_tool(
}
try:
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {
"status": "error",
"message": context["error"],
}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Notion accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "notion",
}
logger.info(f"Requesting approval for creating Notion page: '{title}'")
result = request_approval(
action_type="notion_page_creation",
tool_name="create_notion_page",
params={
"title": title,
"content": content,
"parent_page_id": None,
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Notion page creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_content = result.params.get("content", content)
final_parent_page_id = result.params.get("parent_page_id")
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
"message": "Page title cannot be empty. Please provide a valid title.",
}
logger.info(
f"Creating Notion page with final params: title='{final_title}'"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
async with async_session_maker() as db_session:
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
connector = result.scalars().first()
if not connector:
logger.warning(
f"No Notion connector found for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Notion connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
if "error" in context:
logger.error(
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
f"Failed to fetch creation context: {context['error']}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
"message": context["error"],
}
logger.info(f"Validated Notion connector: id={actual_connector_id}")
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Notion accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "notion",
}
result = await notion_connector.create_page(
title=final_title,
content=final_content,
parent_page_id=final_parent_page_id,
)
logger.info(
f"create_page result: {result.get('status')} - {result.get('message', '')}"
)
logger.info(f"Requesting approval for creating Notion page: '{title}'")
result = request_approval(
action_type="notion_page_creation",
tool_name="create_notion_page",
params={
"title": title,
"content": content,
"parent_page_id": None,
"connector_id": connector_id,
},
context=context,
)
if result.get("status") == "success":
kb_message_suffix = ""
try:
from app.services.notion import NotionKBSyncService
if result.rejected:
logger.info("Notion page creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=result.get("page_id"),
page_title=result.get("title", final_title),
page_url=result.get("url"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
final_title = result.params.get("title", title)
final_content = result.params.get("content", content)
final_parent_page_id = result.params.get("parent_page_id")
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
"message": "Page title cannot be empty. Please provide a valid title.",
}
logger.info(
f"Creating Notion page with final params: title='{final_title}'"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
else:
)
connector = result.scalars().first()
if not connector:
logger.warning(
f"No Notion connector found for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Notion connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
logger.info(f"Validated Notion connector: id={actual_connector_id}")
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.create_page(
title=final_title,
content=final_content,
parent_page_id=final_parent_page_id,
)
logger.info(
f"create_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success":
kb_message_suffix = ""
try:
from app.services.notion import NotionKBSyncService
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=result.get("page_id"),
page_title=result.get("title", final_title),
page_url=result.get("url"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
result["message"] = result.get("message", "") + kb_message_suffix
result["message"] = result.get("message", "") + kb_message_suffix
return result
return result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
"""
Factory function to create the delete_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Database session for accessing Notion connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for finding the correct Notion connector
connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
Returns:
Configured delete_notion_page tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_notion_page(
@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
}
try:
# Get page context (page_id, account, title) from indexed data
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_delete_context(
search_space_id, user_id, page_title
)
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
async with async_session_maker() as db_session:
# Get page context (page_id, account, title) from indexed data
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_delete_context(
search_space_id, user_id, page_title
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id")
connector_id_from_context = account.get("id")
document_id = context.get("document_id")
logger.info(
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="notion_page_deletion",
tool_name="delete_notion_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Notion page deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
logger.info(
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Validate the connector
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
# Create connector instance
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
# Delete the page from Notion
result = await notion_connector.delete_page(page_id=final_page_id)
logger.info(
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
)
# If deletion was successful and user wants to delete from KB
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from sqlalchemy.future import select
from app.db import Document
# Get the document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
)
logger.error(f"Failed to fetch delete context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
# Update result with KB deletion status
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} (also removed from knowledge base)"
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
return result
page_id = context.get("page_id")
connector_id_from_context = account.get("id")
document_id = context.get("document_id")
logger.info(
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="notion_page_deletion",
tool_name="delete_notion_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Notion page deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info(
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Validate the connector
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
# Create connector instance
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
# Delete the page from Notion
result = await notion_connector.delete_page(page_id=final_page_id)
logger.info(
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
)
# If deletion was successful and user wants to delete from KB
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from sqlalchemy.future import select
from app.db import Document
# Get the document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
)
# Update result with KB deletion status
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} (also removed from knowledge base)"
)
return result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_update_notion_page_tool(
"""
Factory function to create the update_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache (see
``create_create_notion_page_tool`` for the full rationale).
Args:
db_session: Database session for accessing Notion connector
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_notion_page_tool(
Returns:
Configured update_notion_page tool
"""
del db_session # per-call session — see docstring
@tool
async def update_notion_page(
@ -71,7 +79,7 @@ def create_update_notion_page_tool(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
logger.error(
"Notion tool not properly configured - missing required parameters"
)
@ -88,152 +96,155 @@ def create_update_notion_page_tool(
}
try:
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, page_title
)
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id")
document_id = context.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
logger.info(
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
)
result = request_approval(
action_type="notion_page_update",
tool_name="update_notion_page",
params={
"page_id": page_id,
"content": content,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Notion page update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_content = result.params.get("content", content)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
logger.info(
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.update_page(
page_id=final_page_id,
content=final_content,
)
logger.info(
f"update_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success" and document_id is not None:
from app.services.notion import NotionKBSyncService
logger.info(f"Updating knowledge base for document {document_id}...")
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
appended_content=final_content,
user_id=user_id,
search_space_id=search_space_id,
appended_block_ids=result.get("appended_block_ids"),
async with async_session_maker() as db_session:
metadata_service = NotionToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, page_title
)
if kb_result["status"] == "success":
result["message"] = (
f"{result['message']}. Your knowledge base has also been updated."
)
logger.info(
f"Knowledge base successfully updated for page {final_page_id}"
)
elif kb_result["status"] == "not_indexed":
result["message"] = (
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
)
else:
result["message"] = (
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
)
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
f"KB update failed for page {final_page_id}: {kb_result['message']}"
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id")
document_id = context.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
logger.info(
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
)
result = request_approval(
action_type="notion_page_update",
tool_name="update_notion_page",
params={
"page_id": page_id,
"content": content,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Notion page update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_content = result.params.get("content", content)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
logger.info(
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.update_page(
page_id=final_page_id,
content=final_content,
)
logger.info(
f"update_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success" and document_id is not None:
from app.services.notion import NotionKBSyncService
logger.info(
f"Updating knowledge base for document {document_id}..."
)
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
appended_content=final_content,
user_id=user_id,
search_space_id=search_space_id,
appended_block_ids=result.get("appended_block_ids"),
)
return result
if kb_result["status"] == "success":
result["message"] = (
f"{result['message']}. Your knowledge base has also been updated."
)
logger.info(
f"Knowledge base successfully updated for page {final_page_id}"
)
elif kb_result["status"] == "not_indexed":
result["message"] = (
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
)
else:
result["message"] = (
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
)
logger.warning(
f"KB update failed for page {final_page_id}: {kb_result['message']}"
)
return result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool
async def create_onedrive_file(
name: str,
@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
"""
logger.info(f"create_onedrive_file called: name='{name}'")
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
connectors = result.scalars().all()
if not connectors:
return {
"status": "error",
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected OneDrive accounts need re-authentication.",
"connector_type": "onedrive",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = OneDriveClient(session=db_session, connector_id=cid)
items, err = await client.list_children("root")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{"folder_id": item["id"], "name": item["name"]}
for item in items
if item.get("folder") is not None
and item.get("id")
and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
}
result = request_approval(
action_type="onedrive_file_creation",
tool_name="create_onedrive_file",
params={
"name": name,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_docx_extension(final_name)
if final_connector_id is not None:
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
else:
connector = connectors[0]
connectors = result.scalars().all()
if not connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid.",
if not connectors:
return {
"status": "error",
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected OneDrive accounts need re-authentication.",
"connector_type": "onedrive",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = OneDriveClient(session=db_session, connector_id=cid)
items, err = await client.list_children("root")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{"folder_id": item["id"], "name": item["name"]}
for item in items
if item.get("folder") is not None
and item.get("id")
and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s",
cid,
exc_info=True,
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
}
docx_bytes = _markdown_to_docx(final_content or "")
client = OneDriveClient(session=db_session, connector_id=connector.id)
created = await client.create_file(
name=final_name,
parent_id=final_parent_folder_id,
content=docx_bytes,
mime_type=DOCX_MIME,
)
logger.info(
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.onedrive import OneDriveKBSyncService
kb_service = OneDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=DOCX_MIME,
web_url=created.get("webUrl"),
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
result = request_approval(
action_type="onedrive_file_creation",
tool_name="create_onedrive_file",
params={
"name": name,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_url": created.get("webUrl"),
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_docx_extension(final_name)
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
else:
connector = connectors[0]
if not connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid.",
}
docx_bytes = _markdown_to_docx(final_content or "")
client = OneDriveClient(session=db_session, connector_id=connector.id)
created = await client.create_file(
name=final_name,
parent_id=final_parent_folder_id,
content=docx_bytes,
mime_type=DOCX_MIME,
)
logger.info(
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.onedrive import OneDriveKBSyncService
kb_service = OneDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=DOCX_MIME,
web_url=created.get("webUrl"),
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_url": created.get("webUrl"),
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the delete_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool
async def delete_onedrive_file(
file_name: str,
@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "OneDrive tool not properly configured.",
}
try:
doc_result = await db_session.execute(
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
document = doc_result.scalars().first()
if not document:
async with async_session_maker() as db_session:
doc_result = await db_session.execute(
select(Document)
.join(
@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower(
cast(
Document.document_metadata["onedrive_file_name"],
String,
)
)
== func.lower(file_name),
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
)
document = doc_result.scalars().first()
if not document:
return {
"status": "not_found",
"message": (
f"File '{file_name}' not found in your indexed OneDrive files. "
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the file name is different."
),
}
if not document.connector_id:
return {
"status": "error",
"message": "Document has no associated connector.",
}
meta = document.document_metadata or {}
file_id = meta.get("onedrive_file_id")
document_id = document.id
if not file_id:
return {
"status": "error",
"message": "File ID is missing. Please re-index the file.",
}
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
if not document:
doc_result = await db_session.execute(
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower(
cast(
Document.document_metadata[
"onedrive_file_name"
],
String,
)
)
== func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
)
)
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "OneDrive connector not found or access denied.",
}
document = doc_result.scalars().first()
cfg = connector.config or {}
if cfg.get("auth_expired"):
return {
"status": "auth_error",
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "onedrive",
}
if not document:
return {
"status": "not_found",
"message": (
f"File '{file_name}' not found in your indexed OneDrive files. "
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the file name is different."
),
}
context = {
"file": {
"file_id": file_id,
"name": file_name,
"document_id": document_id,
"web_url": meta.get("web_url"),
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
if not document.connector_id:
return {
"status": "error",
"message": "Document has no associated connector.",
}
result = request_approval(
action_type="onedrive_file_trash",
tool_name="delete_onedrive_file",
params={
"file_id": file_id,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
meta = document.document_metadata or {}
file_id = meta.get("onedrive_file_id")
document_id = document.id
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
if not file_id:
return {
"status": "error",
"message": "File ID is missing. Please re-index the file.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if final_connector_id != connector.id:
result = await db_session.execute(
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid or has been disconnected.",
"message": "OneDrive connector not found or access denied.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
)
cfg = connector.config or {}
if cfg.get("auth_expired"):
return {
"status": "auth_error",
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "onedrive",
}
client = OneDriveClient(
session=db_session, connector_id=actual_connector_id
)
await client.trash_file(final_file_id)
context = {
"file": {
"file_id": file_id,
"name": file_name,
"document_id": document_id,
"web_url": meta.get("web_url"),
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
logger.info(
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file_name}' to the recycle bin.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
result = request_approval(
action_type="onedrive_file_trash",
tool_name="delete_onedrive_file",
params={
"file_id": file_id,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
return trash_result
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid or has been disconnected.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
)
client = OneDriveClient(
session=db_session, connector_id=actual_connector_id
)
await client.trash_file(final_file_id)
logger.info(
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file_name}' to the recycle bin.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -824,13 +824,22 @@ async def build_tools_async(
"""Async version of build_tools that also loads MCP tools from database.
Design Note:
This function exists because MCP tools require database queries to load user configs,
while built-in tools are created synchronously from static code.
This function exists because MCP tools require database queries to load
user configs, while built-in tools are created synchronously from static
code.
Alternative: We could make build_tools() itself async and always query the database,
but that would force async everywhere even when only using built-in tools. The current
design keeps the simple case (static tools only) synchronous while supporting dynamic
database-loaded tools through this async wrapper.
Alternative: We could make build_tools() itself async and always query
the database, but that would force async everywhere even when only using
built-in tools. The current design keeps the simple case (static tools
only) synchronous while supporting dynamic database-loaded tools through
this async wrapper.
Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
the event loop) are kicked off concurrently. Cold-path savings are
bounded by the slower of the two typically MCP at ~200ms-1.7s
so the parallelization recovers the ~50-200ms previously spent
serially on built-in construction.
Args:
dependencies: Dict containing all possible dependencies
@ -843,33 +852,70 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
import asyncio
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
can_load_mcp = (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
)
# Built-in tool construction is synchronous + CPU-only. Off-loop it so
# MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
# function over its inputs — safe to thread-shift.
_t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
builtin_task = asyncio.create_task(
asyncio.to_thread(
build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
)
)
mcp_task: asyncio.Task | None = None
if can_load_mcp:
mcp_task = asyncio.create_task(
load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
)
# Surface failures from each task independently so a flaky MCP
# endpoint never poisons built-in tool registration. ``return_exceptions``
# gives us per-task exceptions instead of dropping the second result
# when the first raises.
if mcp_task is not None:
builtin_result, mcp_result = await asyncio.gather(
builtin_task, mcp_task, return_exceptions=True
)
else:
builtin_result = await builtin_task
mcp_result = None
if isinstance(builtin_result, BaseException):
raise builtin_result # built-in registration failure is non-recoverable
tools: list[BaseTool] = builtin_result
_perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
"[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(tools),
)
# Load MCP tools if requested and dependencies are available
if (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
):
try:
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
if mcp_task is not None:
if isinstance(mcp_result, BaseException):
# ``return_exceptions=True`` captures the exception out-of-band,
# so ``sys.exc_info()`` is empty here. Pass the captured
# exception via ``exc_info=`` to get a real traceback.
logging.error(
"Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
)
else:
mcp_tools = mcp_result or []
_perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
"[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0,
len(mcp_tools),
)
@ -879,8 +925,6 @@ async def build_tools_async(
len(mcp_tools),
[t.name for t in mcp_tools],
)
except Exception as e:
logging.exception("Failed to load MCP tools: %s", e)
logging.info(
"Total tools for agent: %d%s",

View file

@ -15,7 +15,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text
@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
"""
Factory function to create the search_surfsense_docs tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Database session for executing queries
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
A configured tool function for searching Surfsense documentation
"""
del db_session # per-call session — see docstring
@tool
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
Returns:
Relevant documentation content formatted with chunk IDs for citations
"""
return await search_surfsense_docs_async(
query=query,
db_session=db_session,
top_k=top_k,
)
async with async_session_maker() as db_session:
return await search_surfsense_docs_async(
query=query,
db_session=db_session,
top_k=top_k,
)
return search_surfsense_docs

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the list_teams_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_teams_channels tool
"""
del db_session # per-call session — see docstring
@tool
async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to.
@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name).
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
token = await get_access_token(db_session, connector)
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=20.0) as client:
teams_resp = await client.get(
f"{GRAPH_API}/me/joinedTeams", headers=headers
async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
if teams_resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if teams_resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {teams_resp.status_code}",
}
token = await get_access_token(db_session, connector)
headers = {"Authorization": f"Bearer {token}"}
teams_data = teams_resp.json().get("value", [])
result_teams = []
async with httpx.AsyncClient(timeout=20.0) as client:
for team in teams_data:
team_id = team["id"]
ch_resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels",
headers=headers,
)
channels = []
if ch_resp.status_code == 200:
channels = [
{"id": ch["id"], "name": ch.get("displayName", "")}
for ch in ch_resp.json().get("value", [])
]
result_teams.append(
{
"team_id": team_id,
"team_name": team.get("displayName", ""),
"channels": channels,
}
async with httpx.AsyncClient(timeout=20.0) as client:
teams_resp = await client.get(
f"{GRAPH_API}/me/joinedTeams", headers=headers
)
return {
"status": "success",
"teams": result_teams,
"total_teams": len(result_teams),
}
if teams_resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if teams_resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {teams_resp.status_code}",
}
teams_data = teams_resp.json().get("value", [])
result_teams = []
async with httpx.AsyncClient(timeout=20.0) as client:
for team in teams_data:
team_id = team["id"]
ch_resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels",
headers=headers,
)
channels = []
if ch_resp.status_code == 200:
channels = [
{"id": ch["id"], "name": ch.get("displayName", "")}
for ch in ch_resp.json().get("value", [])
]
result_teams.append(
{
"team_id": team_id,
"team_name": team.get("displayName", ""),
"channels": channels,
}
)
return {
"status": "success",
"teams": result_teams,
"total_teams": len(result_teams),
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the read_teams_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_teams_messages tool
"""
del db_session # per-call session — see docstring
@tool
async def read_teams_messages(
team_id: str,
@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
Dictionary with status and a list of messages including
id, sender, content, timestamp.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50)
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
headers={"Authorization": f"Bearer {token}"},
params={"$top": limit},
async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Insufficient permissions to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}",
}
token = await get_access_token(db_session, connector)
raw_msgs = resp.json().get("value", [])
messages = []
for m in raw_msgs:
sender = m.get("from", {})
user_info = sender.get("user", {}) if sender else {}
body = m.get("body", {})
messages.append(
{
"id": m.get("id"),
"sender": user_info.get("displayName", "Unknown"),
"content": body.get("content", ""),
"content_type": body.get("contentType", "text"),
"timestamp": m.get("createdDateTime", ""),
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
headers={"Authorization": f"Bearer {token}"},
params={"$top": limit},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Insufficient permissions to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}",
}
)
return {
"status": "success",
"team_id": team_id,
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
raw_msgs = resp.json().get("value", [])
messages = []
for m in raw_msgs:
sender = m.get("from", {})
user_info = sender.get("user", {}) if sender else {}
body = m.get("body", {})
messages.append(
{
"id": m.get("id"),
"sender": user_info.get("displayName", "Unknown"),
"content": body.get("content", ""),
"content_type": body.get("contentType", "text"),
"timestamp": m.get("createdDateTime", ""),
}
)
return {
"status": "success",
"team_id": team_id,
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
@ -17,6 +18,23 @@ def create_send_teams_message_tool(
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the send_teams_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_teams_message tool
"""
del db_session # per-call session — see docstring
@tool
async def send_teams_message(
team_id: str,
@ -39,70 +57,73 @@ def create_send_teams_message_tool(
IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry.
"""
if db_session is None or search_space_id is None or user_id is None:
if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."}
try:
connector = await get_teams_connector(db_session, search_space_id, user_id)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
async with async_session_maker() as db_session:
connector = await get_teams_connector(
db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
result = request_approval(
action_type="teams_send_message",
tool_name="send_teams_message",
params={
"team_id": team_id,
"channel_id": channel_id,
"content": content,
},
context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
final_content = result.params.get("content", content)
final_team = result.params.get("team_id", team_id)
final_channel = result.params.get("channel_id", channel_id)
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
result = request_approval(
action_type="teams_send_message",
tool_name="send_teams_message",
params={
"team_id": team_id,
"channel_id": channel_id,
"content": content,
},
json={"body": {"content": final_content}},
context={"connector_id": connector.id},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "insufficient_permissions",
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}{resp.text[:200]}",
}
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": "Message sent to Teams channel.",
}
final_content = result.params.get("content", content)
final_team = result.params.get("team_id", team_id)
final_channel = result.params.get("channel_id", channel_id)
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json={"body": {"content": final_content}},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "insufficient_permissions",
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}{resp.text[:200]}",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": "Message sent to Teams channel.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt

View file

@ -26,7 +26,7 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User
from app.db import SearchSpace, User, async_session_maker
from app.utils.content_utils import extract_text_content
logger = logging.getLogger(__name__)
@ -302,6 +302,25 @@ def create_update_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the user-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
user_id: ID of the user whose memory document is being updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the user-memory scope.
"""
del db_session # per-call session — see docstring
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
@ -318,26 +337,26 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
async with async_session_maker() as db_session:
result = await db_session.execute(select(User).where(User.id == uid))
user = result.scalars().first()
if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md
old_memory = user.memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="memory",
scope="user",
)
except Exception as e:
logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update memory: {e}",
@ -351,6 +370,27 @@ def create_update_team_memory_tool(
db_session: AsyncSession,
llm: Any | None = None,
):
"""Factory function to create the team-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
search_space_id: ID of the search space whose team memory is being
updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the team-memory scope.
"""
del db_session # per-call session — see docstring
@tool
async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space.
@ -366,28 +406,30 @@ def create_update_team_memory_tool(
updated_memory: The FULL updated markdown document (not a diff).
"""
try:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
async with async_session_maker() as db_session:
result = await db_session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
space = result.scalars().first()
if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md
old_memory = space.shared_memory_md
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
return await _save_memory(
updated_memory=updated_memory,
old_memory=old_memory,
llm=llm,
apply_fn=lambda content: setattr(
space, "shared_memory_md", content
),
commit_fn=db_session.commit,
rollback_fn=db_session.rollback,
label="team memory",
scope="team",
)
except Exception as e:
logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
return {
"status": "error",
"message": f"Failed to update team memory: {e}",

View file

@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
OpenRouterIntegrationService.get_instance().stop_background_refresh()
async def _warm_agent_jit_caches() -> None:
"""Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
Why
----
A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
generation chain takes 1.5-2 seconds of pure CPU on first invocation
inside any Python process: the graph compiler builds reducers,
Pydantic v2 generates and JITs validator schemas, deepagents
eagerly compiles its general-purpose subagent, etc. Subsequent
compiles in the same process pay only ~50% of that cost (the lazy
JIT bits are cached in module-level dicts).
Doing one throwaway compile during ``lifespan`` startup pre-pays
that cost so the *first real request* doesn't. We do NOT prime
:mod:`agent_cache` because the cache key requires real
``thread_id`` / ``user_id`` / ``search_space_id`` / etc. the
throwaway agent is genuinely thrown away and immediately collected.
Safety
------
* No DB access. We construct a stub LLM (no real keys), pass an
empty tools list, and pass ``checkpointer=None`` so we never
touch Postgres.
* Bounded by ``asyncio.wait_for`` so a hang here can never block
worker startup. On any failure, we log + swallow the worst
case is the first real request pays the full cold cost (i.e.
pre-warmup behaviour).
"""
import time as _time
logger = logging.getLogger(__name__)
t0 = _time.perf_counter()
try:
from langchain.agents import create_agent
from langchain.agents.middleware import (
ModelCallLimitMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
)
from langchain_core.tools import tool
from app.agents.new_chat.context import SurfSenseContextSchema
# Minimal LLM stub. ``FakeListChatModel`` satisfies
# ``BaseChatModel`` without any network or auth — perfect for
# exercising the compile path without side effects.
stub_llm = FakeListChatModel(responses=["warmup-response"])
# Two trivial tools with arg + return schemas — exercises the
# Pydantic v2 schema JIT path. Without at least one tool the
# graph compile skips the tool-loop bytecode generation that
# accounts for ~30-50% of cold compile cost.
@tool
def _warmup_tool_a(query: str, limit: int = 5) -> str:
"""Warmup tool A — never actually invoked."""
return query[:limit]
@tool
def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
"""Warmup tool B — never actually invoked."""
return {"name": name, "value": value}
# A handful of common middleware so the compile pre-pays the
# ``AgentMiddleware`` resolver path. These instances never run
# because the throwaway agent is immediately collected.
# ``SubAgentMiddleware`` is the single heaviest line in cold
# ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
# compile its general-purpose subagent's full inner graph),
# so we include it here to make sure that compile path is JIT'd.
warmup_middleware: list = [
TodoListMiddleware(),
ModelCallLimitMiddleware(
thread_limit=120, run_limit=80, exit_behavior="end"
),
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
),
]
try:
from deepagents import SubAgentMiddleware
from deepagents.backends import StateBackend
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
gp_warmup_spec = { # type: ignore[var-annotated]
**GENERAL_PURPOSE_SUBAGENT,
"model": stub_llm,
"tools": [_warmup_tool_a],
"middleware": [TodoListMiddleware()],
}
warmup_middleware.append(
SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
)
except Exception:
# Deepagents missing/incompatible — middleware-only warmup
# still produces a useful (smaller) speedup.
logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
compiled = create_agent(
stub_llm,
tools=[_warmup_tool_a, _warmup_tool_b],
system_prompt="You are a warmup stub.",
middleware=warmup_middleware,
context_schema=SurfSenseContextSchema,
checkpointer=None,
)
# Touch the compiled graph's stream_channels / nodes so any
# remaining lazy schema work fires now instead of on first
# real invocation.
_ = list(getattr(compiled, "nodes", {}).keys())
del compiled
logger.info(
"[startup] Agent JIT warmup completed in %.3fs",
_time.perf_counter() - t0,
)
except Exception:
logger.warning(
"[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
"real request will pay the full compile cost)",
_time.perf_counter() - t0,
exc_info=True,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
@ -432,6 +562,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables()
initialize_openrouter_integration()
_start_openrouter_background_refresh()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
@ -443,6 +574,18 @@ async def lifespan(app: FastAPI):
"Docs will be indexed on the next restart."
)
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
# worker readiness. ``shield`` so Uvicorn cancelling startup
# doesn't leave half-warmed Pydantic schemas in an inconsistent
# state.
try:
await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
except (TimeoutError, Exception): # pragma: no cover - defensive
logging.getLogger(__name__).warning(
"[startup] Agent JIT warmup hit timeout/error — skipping; "
"first real request will pay the full compile cost."
)
log_system_snapshot("startup_complete")
yield
@ -452,6 +595,23 @@ async def lifespan(app: FastAPI):
def registration_allowed():
"""Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
Despite the name, this dependency does NOT only gate registration. When
REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
that could mint or refresh a session for an attacker:
* email/password ``POST /auth/register``
* email/password ``POST /auth/jwt/login``
* the Google OAuth router (``/auth/google/authorize`` and the shared
``/auth/google/callback`` handles both new signups and login for
existing users, so flipping this off locks both)
* the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
Use it as a temporary "freeze all new sessions" lever during incident
response. It is not a way to disable signup while keeping login working;
for that, override ``UserManager.oauth_callback`` instead.
"""
if not config.REGISTRATION_ENABLED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
@ -596,32 +756,45 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers
)
app.include_router(
fastapi_users.get_auth_router(auth_backend),
prefix="/auth/jwt",
tags=["auth"],
dependencies=[Depends(rate_limit_login)],
)
app.include_router(
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
dependencies=[
Depends(rate_limit_register),
Depends(registration_allowed), # blocks registration when disabled
],
)
app.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
dependencies=[Depends(rate_limit_password_reset)],
)
app.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
# Password / email-based auth routers are only mounted when not running in
# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
# POST /auth/register reachable, which is the bypass that allowed bots to
# create non-OAuth users in spite of AUTH_TYPE=GOOGLE.
if config.AUTH_TYPE != "GOOGLE":
app.include_router(
fastapi_users.get_auth_router(auth_backend),
prefix="/auth/jwt",
tags=["auth"],
dependencies=[
Depends(rate_limit_login),
Depends(
registration_allowed
), # honour REGISTRATION_ENABLED kill switch on login too
],
)
app.include_router(
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
dependencies=[
Depends(rate_limit_register),
Depends(registration_allowed),
],
)
app.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
dependencies=[Depends(rate_limit_password_reset)],
)
app.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
# /users/me (read/update profile) is needed in every auth mode, so it stays
# mounted unconditionally.
app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE":
),
prefix="/auth/google",
tags=["auth"],
dependencies=[
Depends(registration_allowed)
], # blocks OAuth registration when disabled
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
# it blocks BOTH new OAuth signups AND login of existing OAuth users
# (the fastapi-users OAuth router shares one callback for create+login,
# so this dependency closes both paths together).
dependencies=[Depends(registration_allowed)],
)
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
# This endpoint performs a server-side redirect instead of returning JSON
# which fixes cross-site cookie issues where browsers don't send cookies
# set via cross-origin fetch requests on subsequent redirects
@app.get("/auth/google/authorize-redirect", tags=["auth"])
# set via cross-origin fetch requests on subsequent redirects.
# The registration_allowed dependency mirrors the OAuth router above so
# the kill switch fails fast here instead of bouncing users to Google
# only to 403 on the callback.
@app.get(
"/auth/google/authorize-redirect",
tags=["auth"],
dependencies=[Depends(registration_allowed)],
)
async def google_authorize_redirect(
request: Request,
):

View file

@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router,
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()

View file

@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f)
configs = data.get("global_llm_configs", [])
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {}
for cfg in configs:
cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False)
# Capability flag: explicit YAML override always wins. When the
# operator has not annotated the model, defer to LiteLLM's
# authoritative model map (`supports_vision`) which already
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
# vision-capable. Unknown / unmapped models default-allow so
# we don't lock the user out of a freshly added third-party
# entry; the streaming-task safety net (driven by
# `is_known_text_only_chat_model`) is the only place a False
# actually blocks a request.
if "supports_image_input" not in cfg:
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"]
@ -138,7 +164,11 @@ def load_global_image_gen_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
return data.get("global_image_generation_configs", [])
configs = data.get("global_image_generation_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}")
return []
@ -153,7 +183,11 @@ def load_global_vision_llm_configs():
try:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
return data.get("global_vision_llm_configs", [])
configs = data.get("global_vision_llm_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
@ -254,6 +288,15 @@ def load_openrouter_integration_settings() -> dict | None:
"anonymous_enabled_free", settings["anonymous_enabled"]
)
# Image generation + vision LLM emission are opt-in (issue L).
# OpenRouter's catalogue contains hundreds of image / vision
# capable models; auto-injecting all of them into every
# deployment would explode the model selector and surprise
# operators upgrading from prior versions. Default to False so
# admins must explicitly turn them on.
settings.setdefault("image_generation_enabled", False)
settings.setdefault("vision_enabled", False)
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
@ -296,10 +339,60 @@ def initialize_openrouter_integration():
)
else:
print("Info: OpenRouter integration enabled but no models fetched")
# Image generation + vision LLM emissions are opt-in (issue L).
# Both reuse the catalogue already cached by ``service.initialize``
# so we don't make additional network calls here.
if settings.get("image_generation_enabled"):
try:
image_configs = service.get_image_generation_configs()
if image_configs:
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
print(
f"Info: OpenRouter integration added {len(image_configs)} "
f"image-generation models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
if settings.get("vision_enabled"):
try:
vision_configs = service.get_vision_llm_configs()
if vision_configs:
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
print(
f"Info: OpenRouter integration added {len(vision_configs)} "
f"vision LLM models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
def initialize_pricing_registration():
"""
Teach LiteLLM the per-token cost of every deployment in
``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
from the OpenRouter catalogue + any operator-declared YAML pricing).
Must run AFTER ``initialize_openrouter_integration()`` so the
OpenRouter catalogue is populated and BEFORE the first LLM call so
``response_cost`` is available in ``TokenTrackingCallback``.
Failures are logged but never raised startup must not be blocked
by a missing pricing entry; the worst-case is the model debits 0.
"""
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as e:
print(f"Warning: Failed to register LiteLLM pricing: {e}")
def initialize_llm_router():
"""
Initialize the LLM Router service for Auto mode.
@ -444,14 +537,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
# Premium token quota settings
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
# Premium credit (micro-USD) quota settings.
#
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
# still honoured for one release as fall-back values — the prior
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
# to micros, so operators upgrading without changing their .env still
# get correct behaviour. A startup deprecation warning fires below if
# they're set.
PREMIUM_CREDIT_MICROS_LIMIT = int(
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
)
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
STRIPE_CREDIT_MICROS_PER_UNIT = int(
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
)
STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
)
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
# config's ``quota_reserve_tokens`` and clamps the result to this value
# so a misconfigured "$1000/M" model can't lock the user's whole balance
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
# reserve_tokens ≈ $0.36) with headroom.
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
"PREMIUM_CREDIT_MICROS_LIMIT"
):
print(
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
"current Stripe price). The old key will be removed in a "
"future release."
)
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
"STRIPE_CREDIT_MICROS_PER_UNIT"
):
print(
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
"The old key will be removed in a future release."
)
# Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
@ -464,6 +597,35 @@ class Config:
# Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
# ``POST /image-generations`` endpoint when the global config does not
# override it. $0.05 covers realistic worst-cases for current OpenAI /
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
)
# Per-podcast reservation (in micro-USD). One agent LLM call generating
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
# premium-model run. Tune via env.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
)
# Per-video-presentation reservation (in micro-USD). Fan-out of N
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
# plus refine retries; can produce many premium completions. $1.00
# covers worst-case. Tune via env.
#
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
# 1_000_000. The override path in ``billable_call`` bypasses the
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
# *actual* hold — raising it via env is fine but means a single video
# task can lock $1+ of credit.
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
)
# Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int(

View file

@ -19,6 +19,24 @@
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
#
# COST-BASED PREMIUM CREDITS:
# Each premium config bills the user's USD-credit balance based on the
# actual provider cost reported by LiteLLM. For models LiteLLM already
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
# or any model LiteLLM doesn't have in its built-in pricing table, declare
# per-token costs inline so they bill correctly:
#
# litellm_params:
# base_model: "my-custom-azure-deploy"
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
# input_cost_per_token: 0.000003
# output_cost_per_token: 0.000015
#
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
# API — no inline declaration needed. Models without resolvable pricing
# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
@ -292,6 +310,17 @@ openrouter_integration:
free_rpm: 20
free_tpm: 100000
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
# contains hundreds of image- and vision-capable models; turning these on
# injects them into the global Image-Generation / Vision-LLM model
# selectors alongside any static configs. Tier (free/premium) is derived
# per model the same way it is for chat (`:free` suffix or zero pricing).
# When a user picks a premium image/vision model the call debits the
# shared $5 USD-cost-based premium credit pool — so leaving these off
# avoids surprise quota burn on existing deployments. Default: false.
image_generation_enabled: false
vision_enabled: false
litellm_params:
max_tokens: 16384
system_instructions: ""

View file

@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True)
@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin):
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
Note: the table name is preserved (``premium_token_purchases``) for
operational continuity even though the unit is now USD micro-credits
instead of raw tokens. The ``credit_micros_granted`` column replaced
the legacy ``tokens_granted`` in migration 140; the stored values
were not transformed because the prior $1 = 1M tokens Stripe price
makes the unit conversion 1:1 numerically.
"""
__tablename__ = "premium_token_purchases"
__allow_unmapped__ = True
@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
)
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False)
tokens_granted = Column(BigInteger, nullable=False)
credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True)
status = Column(
@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
premium_tokens_used = Column(
premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
@ -2241,16 +2250,16 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column(
premium_credit_micros_limit = Column(
BigInteger,
nullable=False,
default=config.PREMIUM_TOKEN_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT),
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
)
premium_tokens_used = Column(
premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0"
)
premium_tokens_reserved = Column(
premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0"
)

View file

@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM",
content_type="image",
)
except Exception:
logging.warning(
"Vision LLM failed for %s, falling back to document parser",
request.filename,
exc_info=True,
)
except Exception as exc:
# Special-case quota exhaustion so we log a clearer message
# — the vision LLM didn't "fail", the user just ran out of
# premium credit. Falling through to the document parser
# is a graceful degradation: OCR/Unstructured still
# extracts text from the image without burning credit.
from app.services.billable_calls import QuotaInsufficientError
if isinstance(exc, QuotaInsufficientError):
logging.info(
"Vision LLM quota exhausted for %s; falling back to document parser",
request.filename,
)
else:
logging.warning(
"Vision LLM failed for %s, falling back to document parser",
request.filename,
exc_info=True,
)
else:
logging.info(
"No vision LLM provided, falling back to document parser for %s",

View file

@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.config import config
from app.db import User
from app.users import current_active_user
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool
enable_desktop_local_filesystem: bool
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
return cls(**asdict(flags))
return cls(
**asdict(flags),
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
)
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)

View file

@ -649,13 +649,9 @@ async def list_composio_drive_folders(
"""
List folders AND files in user's Google Drive via Composio.
Uses the same GoogleDriveClient / list_folder_contents path as the native
connector, with Composio-sourced credentials. This means auth errors
propagate identically (Google returns 401 exception auth_expired flag).
Uses Composio's Google Drive tool execution path so managed OAuth tokens
do not need to be exposed through connected account state.
"""
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
from app.utils.google_credentials import build_composio_credentials
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.",
)
credentials = build_composio_credentials(composio_connected_account_id)
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
service = ComposioService()
entity_id = f"surfsense_{user.id}"
items = []
page_token = None
error = None
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
while True:
page_items, next_token, page_error = await service.get_drive_files(
connected_account_id=composio_connected_account_id,
entity_id=entity_id,
folder_id=parent_id,
page_token=page_token,
page_size=100,
)
if page_error:
error = page_error
break
items.extend(page_items)
if not next_token:
break
page_token = next_token
for item in items:
item["isFolder"] = (
item.get("mimeType") == "application/vnd.google-apps.folder"
)
items.sort(
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
)
if error:
error_lower = error.lower()

View file

@ -36,11 +36,17 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
billable_call,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
async def _resolve_billing_for_image_gen(
session: AsyncSession,
config_id: int | None,
search_space: SearchSpace,
) -> tuple[str, str, int]:
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
The resolution mirrors ``_execute_image_generation``'s lookup tree but
only extracts the fields needed for billing we do this *before*
``billable_call`` so the reservation is correctly sized for the
config that will actually run, and so we don't open an
``ImageGeneration`` row for a request that's about to 402.
User-owned (positive ID) BYOK configs are always free they cost
the user nothing on our side. Auto mode currently treats as free
because the underlying router can dispatch to either premium or
free YAML configs and we don't surface the resolved deployment up
here yet. Bringing Auto under premium billing would require
threading the chosen deployment back from ``ImageGenRouterService``.
"""
resolved_id = config_id
if resolved_id is None:
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
if is_image_gen_auto_mode(resolved_id):
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
if resolved_id < 0:
cfg = _get_global_image_gen_config(resolved_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
base_model = _build_model_string(
cfg.get("provider", ""),
cfg.get("model_name", ""),
cfg.get("custom_provider"),
)
reserve_micros = int(
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
)
return (billing_tier, base_model, reserve_micros)
# Positive ID = user-owned BYOK image-gen config — always free.
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
async def _execute_image_generation(
@ -138,12 +192,18 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
model_string = _build_model_string(
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -165,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
model_string = _build_model_string(
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@ -454,7 +531,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create and execute an image generation request."""
"""Create and execute an image generation request.
Premium configs are gated by the user's shared premium credit pool.
The flow is:
1. Permission check + load the search space (cheap, no provider call).
2. Resolve which config will run so we know its billing tier and the
worst-case reservation size *before* opening any DB rows.
3. Wrap the entire ImageGeneration row insert + provider call in
``billable_call``. If quota is denied, ``billable_call`` raises
``QuotaInsufficientError`` *before* we flush a row, which we
translate to HTTP 402 (no orphaned rows on the user's account,
no inserted error rows for "you ran out of credit").
4. On success, the actual ``response_cost`` flows through the
LiteLLM callback into the accumulator, and ``billable_call``
finalizes the debit at exit. Inner ``try/except`` still catches
provider errors and stores them on ``error_message`` (HTTP 200
with ``error_message`` set is preserved for failed-but-not-quota
scenarios clients already know how to surface those).
"""
try:
await check_permission(
session,
@ -471,33 +567,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
session, data.image_generation_config_id, search_space
)
session.add(db_image_gen)
await session.flush()
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
# propagates to the outer ``except QuotaInsufficientError`` handler
# below as HTTP 402 — it is intentionally NOT swallowed into
# ``error_message`` because that would (1) imply a successful row
# exists when none does, and (2) return HTTP 200 to a client
# whose request was actively *denied* (issue K).
async with billable_call(
user_id=search_space.user_id,
search_space_id=data.search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=reserve_micros,
usage_type="image_generation",
call_details={"model": base_model, "prompt": data.prompt[:100]},
):
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
)
session.add(db_image_gen)
await session.flush()
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
except HTTPException:
raise
except QuotaInsufficientError as exc:
# The user's premium credit pool is empty. No DB row is created
# because ``billable_call`` denies before yielding (issue K).
await session.rollback()
raise HTTPException(
status_code=402,
detail={
"error_code": "premium_quota_exhausted",
"usage_type": exc.usage_type,
"used_micros": exc.used_micros,
"limit_micros": exc.limit_micros,
"remaining_micros": exc.remaining_micros,
"message": (
"Out of premium credits for image generation. "
"Purchase additional credits or switch to a free model."
),
},
) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(

View file

@ -1366,7 +1366,11 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Persist token usage if provided (for assistant messages)
# Persist token usage if provided (for assistant messages).
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
# forwarded by the FE through the appendMessage round-trip so
# the historical TokenUsage row matches the credit debit applied
# at finalize time.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
@ -1377,6 +1381,7 @@ async def append_message(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,

View file

@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
There is no DB column for ``supports_image_input`` the value is
resolved at the API boundary from LiteLLM's authoritative model map
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
the response shape consistent across list / detail / create / update
endpoints without having to remember to set the field at every call
site.
"""
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
)
# ``model_validate`` runs the Pydantic conversion using the ORM
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
# then we layer the derived field on. ``model_copy(update=...)`` keeps
# the surface immutable from the caller's perspective.
base_read = NewLLMConfigRead.model_validate(config)
return base_read.model_copy(update={"supports_image_input": supports_image_input})
# =============================================================================
# Global Configs Routes
# =============================================================================
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
# Auto routes across the configured pool, which usually
# includes at least one vision-capable deployment, so
# treat Auto as image-capable. The router itself will
# still pick a vision-capable deployment for messages
# carrying image_url blocks (LiteLLM Router falls back
# on ``404`` per its ``allowed_fails`` policy).
"supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
# Capability resolution: explicit value (YAML override or OR
# `_supports_image_input(model)` payload baked in by the
# OpenRouter integration service) wins. Fall back to the
# LiteLLM-driven helper which default-allows on unknown so
# we don't hide vision-capable models that happen to lack a
# YAML annotation. The streaming task safety net is the
# only place a False ever blocks.
if "supports_image_input" in cfg:
supports_image_input = bool(cfg.get("supports_image_input"))
else:
cfg_litellm_params = cfg.get("litellm_params") or {}
cfg_base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
)
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit()
await session.refresh(db_config)
return db_config
return _serialize_byok_config(db_config)
except HTTPException:
raise
@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit)
)
return result.scalars().all()
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space",
)
return config
return _serialize_byok_config(config)
except HTTPException:
raise
@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit()
await session.refresh(config)
return config
return _serialize_byok_config(config)
except HTTPException:
raise

View file

@ -591,6 +591,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -607,6 +608,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None
@ -649,6 +651,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -665,6 +668,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None

View file

@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
# Read the new metadata key first, fall back to the legacy one so
# in-flight checkout sessions created before the cost-credits
# release still fulfil correctly (the unit is numerically the
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
credit_micros_per_unit = int(
metadata.get("credit_micros_per_unit")
or metadata.get("tokens_per_unit", "0")
)
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
tokens_granted=quantity * tokens_per_unit,
credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
user.premium_tokens_limit = (
max(user.premium_tokens_used, user.premium_tokens_limit)
+ purchase.tokens_granted
# Top up the user's credit balance by the granted micro-USD amount.
# ``max(used, limit)`` clamps the case where the legacy code wrote a
# used value above the limit (e.g. underbilling rounding) so adding
# ``credit_micros_granted`` always lifts the limit by the full pack
# size rather than disappearing into past overuse.
user.premium_credit_micros_limit = (
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ purchase.credit_micros_granted
)
await db_session.commit()
@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
"""Create a Stripe Checkout Session for buying premium token packs."""
"""Create a Stripe Checkout Session for buying premium credit packs.
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
credit (default 1_000_000 = $1.00). The user's balance is debited
at the actual provider cost reported by LiteLLM at finalize time,
so $1 of credit always buys $1 worth of provider usage at cost.
"""
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
"purchase_type": "premium_tokens",
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
"purchase_type": "premium_credit",
},
}
)
@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
tokens_granted=tokens_granted,
credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
"""Return token-buying availability and current premium quota for frontend."""
used = user.premium_tokens_used
limit = user.premium_tokens_limit
"""Return token-buying availability and current premium credit quota for frontend.
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
when displaying. The route name is preserved for back-compat with
pinned client deployments.
"""
used = user.premium_credit_micros_used
limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
premium_tokens_used=used,
premium_tokens_limit=limit,
premium_tokens_remaining=max(0, limit - used),
premium_credit_micros_used=used,
premium_credit_micros_limit=limit,
premium_credit_micros_remaining=max(0, limit - used),
)

View file

@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),
}
)

View file

@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(
@ -231,3 +237,24 @@ class GlobalImageGenConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_micros: int | None = Field(
default=None,
description=(
"Optional override for the reservation amount (in micro-USD) used when "
"this image generation is premium. Falls back to "
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
),
)

View file

@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
cost_micros: int = 0
model_breakdown: dict | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (no DB column). Default
# True matches the conservative-allow stance — a BYOK row that the
# route forgot to augment is not pre-judged. The streaming-task
# safety net is the only place a False actually blocks a request.
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map "
"(``litellm.supports_vision``) — there is no DB column. "
"Default True is the conservative-allow stance for unknown / "
"unmapped models."
),
)
model_config = ConfigDict(from_attributes=True)
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (see NewLLMConfigRead).
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map. "
"Default True is the conservative-allow stance."
),
)
model_config = ConfigDict(from_attributes=True)
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
supports_image_input: bool = Field(
default=True,
description=(
"Whether the model accepts image inputs (multimodal vision). "
"Derived server-side: OpenRouter dynamic configs use "
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
"authoritative model map (``litellm.supports_vision``). The "
"new-chat selector hints with a 'No image' badge when this is "
"False and there are pending image attachments. The streaming "
"task fails fast only when LiteLLM *explicitly* marks a model "
"as text-only — unknown / unmapped models default-allow."
),
)
# =============================================================================

View file

@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
class TokenPurchaseRead(BaseModel):
"""Serialized premium token purchase record."""
"""Serialized premium credit purchase record.
``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
schema name kept ``Token`` for API back-compat with pinned clients.
"""
id: uuid.UUID
stripe_checkout_session_id: str
stripe_payment_intent_id: str | None = None
quantity: int
tokens_granted: int
credit_micros_granted: int
amount_total: int | None = None
currency: str | None = None
status: str
@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
class TokenPurchaseHistoryResponse(BaseModel):
"""Response containing the user's premium token purchases."""
"""Response containing the user's premium credit purchases."""
purchases: list[TokenPurchaseRead]
class TokenStripeStatusResponse(BaseModel):
"""Response describing token-buying availability and current quota."""
"""Response describing premium-credit-buying availability and balance.
All ``premium_credit_micros_*`` fields are in micro-USD; the FE
divides by 1_000_000 to display USD.
"""
token_buying_enabled: bool
premium_tokens_used: int = 0
premium_tokens_limit: int = 0
premium_tokens_remaining: int = 0
premium_credit_micros_used: int = 0
premium_credit_micros_limit: int = 0
premium_credit_micros_remaining: int = 0

View file

@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
class GlobalVisionLLMConfigRead(BaseModel):
"""Schema for reading global vision LLM configs from YAML.
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(...)
name: str
description: str | None = None
@ -73,3 +82,35 @@ class GlobalVisionLLMConfigRead(BaseModel):
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_tokens: int | None = Field(
default=None,
description=(
"Optional override for the per-call reservation in *tokens* — "
"converted to micro-USD via the model's input/output prices at "
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
),
)
input_cost_per_token: float | None = Field(
default=None,
description=(
"Optional input price in USD/token. Used by pricing_registration to "
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
),
)
output_cost_per_token: float | None = Field(
default=None,
description="Optional output price in USD/token. Pair with input_cost_per_token.",
)

View file

@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
_healthy_until.pop(int(config_id), None)
def _global_candidates() -> list[dict]:
def _cfg_supports_image_input(cfg: dict) -> bool:
"""True if the global cfg can accept image inputs.
Prefers the explicit ``supports_image_input`` flag (set by the YAML
loader / OpenRouter integration). Falls back to a LiteLLM lookup so
a YAML entry whose flag was somehow stripped doesn't get wrongly
excluded. Default-allows on unknown the streaming-task safety net
is the actual block, not this filter.
"""
if "supports_image_input" in cfg:
return bool(cfg.get("supports_image_input"))
# Lazy import: provider_capabilities -> llm_config -> services chain;
# importing at module load would create an init-order cycle through
# ``app.config``.
from app.services.provider_capabilities import derive_supports_image_input
cfg_litellm_params = cfg.get("litellm_params") or {}
base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
return derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
can't be picked as the thread's pin. Also excludes configs currently
in runtime cooldown (e.g. temporary 429 bursts).
When ``requires_image_input`` is True (image turn), additionally
filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request.
"""
candidates = [
cfg
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
and (not requires_image_input or _cfg_supports_image_input(cfg))
]
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
@ -185,6 +220,15 @@ def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model."""
return (
_tier_of(cfg) == "premium"
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
@ -237,11 +281,20 @@ async def resolve_or_get_pinned_llm_config_id(
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
requires_image_input: bool = False,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
When ``requires_image_input`` is True (the current turn carries an
``image_url`` block), the candidate pool is filtered to vision-capable
cfgs and any existing pin that can't accept image input is treated as
invalid (force re-pin). If no vision-capable cfg is available the
function raises ``ValueError`` so the streaming task surfaces the same
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
silently routing the image to a text-only deployment.
"""
thread = (
(
@ -274,14 +327,24 @@ async def resolve_or_get_pinned_llm_config_id(
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
c
for c in _global_candidates(requires_image_input=requires_image_input)
if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError(
"No vision-capable global LLM configs are available for Auto mode"
)
raise ValueError("No usable global LLM configs are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
# tier switch), unless the caller explicitly requests a forced repin to free.
# tier switch), unless the caller explicitly requests a forced repin to free
# *or* the turn requires image input but the pin can't handle it.
pinned_id = thread.pinned_llm_config_id
if (
not force_repin_free
@ -311,6 +374,29 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=True,
)
if pinned_id is not None:
# If the pin is *only* invalid because it can't handle the image
# turn (it's still a healthy, usable config in the broader pool),
# log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown.
if requires_image_input:
try:
pinned_global = next(
c
for c in config.GLOBAL_LLM_CONFIGS
if int(c.get("id", 0)) == int(pinned_id)
)
except StopIteration:
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
@ -322,11 +408,19 @@ async def resolve_or_get_pinned_llm_config_id(
False if force_repin_free else await _is_premium_eligible(session, user_id)
)
if premium_eligible:
eligible = candidates
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
preferred_premium = [
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
]
eligible = preferred_premium or premium_candidates
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]
if not eligible:
if requires_image_input:
raise ValueError(
"Auto mode could not find a vision-capable LLM config for this user and quota state"
)
raise ValueError(
"Auto mode could not find an eligible LLM config for this user and quota state"
)

View file

@ -0,0 +1,566 @@
"""
Per-call billable wrapper for image generation, vision LLM extraction, and
any other short-lived premium operation that must charge against the user's
shared premium credit pool.
The ``billable_call`` async context manager encapsulates the standard
"reserve → execute → finalize / release → record audit row" lifecycle in a
single primitive so callers (the image-generation REST route and the
vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B):
1. **Session isolation.** ``billable_call`` takes no caller transaction.
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
inside their own session context. Route callers use
``shielded_async_session()`` by default; Celery callers can provide a
worker-loop-safe session factory. This guarantees that quota
commit/rollback can never accidentally flush or roll back rows the caller
has staged in its main session (e.g. a freshly-created
``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
nested ``billable_call`` inside an outer chat turn cannot corrupt the
chat turn's accumulator.
3. **Free configs are still audited.** Free calls bypass the reserve /
finalize dance entirely but still record a ``TokenUsage`` audit row with
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
pipeline complete for analytics even when nothing is debited.
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
responsible for translating that into HTTP 402. We *do not* catch the
denial inside ``billable_call`` letting it propagate also prevents
the image-generation route from creating an ``ImageGeneration`` row
for a request that never actually ran.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
TokenQuotaService,
estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
record_token_usage,
scoped_turn,
)
logger = logging.getLogger(__name__)
AUDIT_TIMEOUT_SECONDS = 10.0
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
{"video_presentation_generation", "podcast_generation"}
)
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
call because the user has exhausted their premium credit pool.
The route handler should catch this and return HTTP 402 Payment
Required (or the equivalent for the surface area). Outside of the HTTP
layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
callers may catch this and degrade gracefully e.g. fall back to OCR
when vision is unavailable.
"""
def __init__(
self,
*,
usage_type: str,
used_micros: int,
limit_micros: int,
remaining_micros: int,
) -> None:
self.usage_type = usage_type
self.used_micros = used_micros
self.limit_micros = limit_micros
self.remaining_micros = remaining_micros
super().__init__(
f"Premium credit exhausted for {usage_type}: "
f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
)
class BillingSettlementError(Exception):
"""Raised when a premium call completed but credit settlement failed."""
def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
self.usage_type = usage_type
self.user_id = user_id
super().__init__(
f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
)
async def _rollback_safely(session: AsyncSession) -> None:
rollback = getattr(session, "rollback", None)
if rollback is not None:
with suppress(Exception):
await rollback()
async def _record_audit_best_effort(
*,
session_factory: BillableSessionFactory,
usage_type: str,
search_space_id: int,
user_id: UUID,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
cost_micros: int,
model_breakdown: dict[str, Any],
call_details: dict[str, Any] | None,
thread_id: int | None,
message_id: int | None,
audit_label: str,
timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> None:
"""Persist a TokenUsage row without letting audit failure block callers.
Premium settlement is mandatory, but TokenUsage is an audit trail. If the
audit insert or commit hangs, user-facing artifacts such as videos and
podcasts must still be able to transition to READY after settlement.
"""
audit_thread_id = (
None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
)
async def _persist() -> None:
logger.info(
"[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
"total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
total_tokens,
cost_micros,
)
async with session_factory() as audit_session:
try:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=audit_thread_id,
message_id=message_id,
)
logger.info(
"[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
audit_label,
usage_type,
user_id,
audit_thread_id,
)
await audit_session.commit()
logger.info(
"[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
audit_label,
usage_type,
user_id,
audit_thread_id,
)
except BaseException:
await _rollback_safely(audit_session)
raise
try:
await asyncio.wait_for(_persist(), timeout=timeout_seconds)
except TimeoutError:
logger.warning(
"[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
"timeout=%.1fs total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
timeout_seconds,
total_tokens,
cost_micros,
)
except Exception:
logger.exception(
"[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
"total_tokens=%d cost_micros=%d",
audit_label,
usage_type,
user_id,
audit_thread_id,
total_tokens,
cost_micros,
)
@asynccontextmanager
async def billable_call(
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None = None,
quota_reserve_micros_override: int | None = None,
usage_type: str,
thread_id: int | None = None,
message_id: int | None = None,
call_details: dict[str, Any] | None = None,
billable_session_factory: BillableSessionFactory | None = None,
audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call.
Args:
user_id: Owner of the credit pool to debit. For vision-LLM during
indexing this is the *search-space owner* (issue M), not the
triggering user.
search_space_id: Required recorded on the ``TokenUsage`` audit row.
billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
the reserve/finalize dance but still records an audit row with
the captured cost.
base_model: Used by :func:`estimate_call_reserve_micros` to compute
a worst-case reservation from LiteLLM's pricing table.
quota_reserve_tokens: Optional per-config override for the chat-style
reserve estimator (vision LLM uses this).
quota_reserve_micros_override: Optional flat micro-USD reservation
(image generation uses this its cost shape is per-image, not
per-token).
usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
Recorded on the ``TokenUsage`` row.
thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``.
billable_session_factory: Optional async context factory used for
reserve/finalize/release/audit sessions. Defaults to
``shielded_async_session`` for route callers; Celery callers pass
a worker-loop-safe session factory.
audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
Audit failure is best-effort and does not undo successful
settlement.
Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
the underlying LLM/image API while inside the ``async with``; the
``TokenTrackingCallback`` populates the accumulator automatically.
Raises:
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
"""
is_premium = billing_tier == "premium"
session_factory = billable_session_factory or shielded_async_session
async with scoped_turn() as acc:
# ---------- Free path: just audit -------------------------------
if not is_premium:
try:
yield acc
finally:
# Always audit, even on exception, so we capture cost when
# provider returns successfully but the caller raises later.
await _record_audit_best_effort(
session_factory=session_factory,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=acc.total_cost_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
audit_label="free",
timeout_seconds=audit_timeout_seconds,
)
return
# ---------- Premium path: reserve → execute → finalize ----------
if quota_reserve_micros_override is not None:
reserve_micros = max(1, int(quota_reserve_micros_override))
else:
reserve_micros = estimate_call_reserve_micros(
base_model=base_model or "",
quota_reserve_tokens=quota_reserve_tokens,
)
request_id = str(uuid4())
async with session_factory() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
reserve_micros=reserve_micros,
)
if not reserve_result.allowed:
logger.info(
"[billable_call] reserve DENIED user=%s usage_type=%s "
"reserve=%d used=%d limit=%d remaining=%d",
user_id,
usage_type,
reserve_micros,
reserve_result.used,
reserve_result.limit,
reserve_result.remaining,
)
raise QuotaInsufficientError(
usage_type=usage_type,
used_micros=reserve_result.used,
limit_micros=reserve_result.limit,
remaining_micros=reserve_result.remaining,
)
logger.info(
"[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
"(remaining=%d)",
user_id,
usage_type,
reserve_micros,
reserve_result.remaining,
)
try:
yield acc
except BaseException:
# Release on any failure (including QuotaInsufficientError raised
# from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases.
try:
async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] premium_release failed for user=%s "
"reserve_micros=%d (reservation will be GC'd by quota "
"reconciliation if/when implemented)",
user_id,
reserve_micros,
)
raise
# ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros
try:
logger.info(
"[billable_call] finalize start user=%s usage_type=%s actual=%d "
"reserved=%d thread=%s",
user_id,
usage_type,
actual_micros,
reserve_micros,
thread_id,
)
async with session_factory() as quota_session:
final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=user_id,
request_id=request_id,
actual_micros=actual_micros,
reserved_micros=reserve_micros,
)
logger.info(
"[billable_call] finalize user=%s usage_type=%s actual=%d "
"reserved=%d → used=%d/%d (remaining=%d)",
user_id,
usage_type,
actual_micros,
reserve_micros,
final_result.used,
final_result.limit,
final_result.remaining,
)
except Exception as finalize_exc:
# Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak.
logger.exception(
"[billable_call] premium_finalize failed for user=%s; "
"attempting release",
user_id,
)
try:
async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
reserved_micros=reserve_micros,
)
except Exception:
logger.exception(
"[billable_call] release after finalize failure ALSO failed "
"for user=%s",
user_id,
)
raise BillingSettlementError(
usage_type=usage_type,
user_id=user_id,
cause=finalize_exc,
) from finalize_exc
await _record_audit_best_effort(
session_factory=session_factory,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=actual_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
audit_label="premium",
timeout_seconds=audit_timeout_seconds,
)
async def _resolve_agent_billing_for_search_space(
session: AsyncSession,
search_space_id: int,
*,
thread_id: int | None = None,
) -> tuple[UUID, str, str]:
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
agent LLM.
Used by Celery tasks (podcast generation, video presentation) to bill the
search-space owner's premium credit pool when the agent LLM is premium.
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
* ``thread_id`` is set: delegate to
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
recurse into the resolved id. Reuses chat's existing pin if present
so the same model bills for chat + downstream podcast/video. If the
user is not premium-eligible, the pin service auto-restricts to free
deployments denial only happens later in
``billable_call.premium_reserve`` if the pin really is premium and
credit ran out mid-flow.
* ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
for any future direct-API path; today both Celery tasks always pass
``thread_id``.
- **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
``base_model = litellm_params.get("base_model") or model_name``
NOT provider-prefixed, matching chat's cost-map lookup convention.
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
``base_model`` from ``litellm_params`` or ``model_name``.
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
``llm_router_service`` are imported lazily inside the function body to
avoid hoisting litellm side-effects (``litellm.callbacks =
[token_tracker]``, ``litellm.drop_params``, etc.) into
``billable_calls.py``'s module load path.
"""
from sqlalchemy import select
from app.db import NewLLMConfig, SearchSpace
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if search_space is None:
raise ValueError(f"Search space {search_space_id} not found")
agent_llm_id = search_space.agent_llm_id
if agent_llm_id is None:
raise ValueError(
f"Search space {search_space_id} has no agent_llm_id configured"
)
owner_user_id: UUID = search_space.user_id
from app.services.auto_model_pin_service import (
AUTO_FASTEST_ID,
resolve_or_get_pinned_llm_config_id,
)
if agent_llm_id == AUTO_FASTEST_ID:
if thread_id is None:
return owner_user_id, "free", "auto"
try:
resolution = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=str(owner_user_id),
selected_llm_config_id=AUTO_FASTEST_ID,
)
except ValueError:
logger.warning(
"[agent_billing] Auto-mode pin resolution failed for "
"search_space=%s thread=%s; falling back to free",
search_space_id,
thread_id,
exc_info=True,
)
return owner_user_id, "free", "auto"
agent_llm_id = resolution.resolved_llm_config_id
if agent_llm_id < 0:
from app.services.llm_service import get_global_llm_config
cfg = get_global_llm_config(agent_llm_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
litellm_params = cfg.get("litellm_params") or {}
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
return owner_user_id, billing_tier, base_model
nlc_result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == agent_llm_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
nlc = nlc_result.scalars().first()
base_model = ""
if nlc is not None:
litellm_params = nlc.litellm_params or {}
base_model = litellm_params.get("base_model") or nlc.model_name or ""
return owner_user_id, "free", base_model
__all__ = [
"BillingSettlementError",
"QuotaInsufficientError",
"_resolve_agent_billing_for_search_space",
"billable_call",
]
# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS

View file

@ -408,12 +408,37 @@ class ComposioService:
files = []
next_token = None
if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
# Try direct access first, then nested
files = data.get("files", []) or data.get("data", {}).get("files", [])
files = (
data.get("files", [])
or (
inner_data.get("files", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("files", [])
)
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
or data.get("data", {}).get("nextPageToken")
or (
inner_data.get("nextPageToken")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("next_page_token")
if isinstance(inner_data, dict)
else None
)
or response_data.get("nextPageToken")
or response_data.get("next_page_token")
)
elif isinstance(data, list):
files = data
@ -819,24 +844,61 @@ class ComposioService:
next_token = None
result_size_estimate = None
if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
messages = (
data.get("messages", [])
or data.get("data", {}).get("messages", [])
or (
inner_data.get("messages", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("messages", [])
or data.get("emails", [])
or (
inner_data.get("emails", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("emails", [])
)
# Check for pagination token in various possible locations
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
or data.get("data", {}).get("nextPageToken")
or data.get("data", {}).get("next_page_token")
or (
inner_data.get("nextPageToken")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("next_page_token")
if isinstance(inner_data, dict)
else None
)
or response_data.get("nextPageToken")
or response_data.get("next_page_token")
)
# Extract resultSizeEstimate if available (Gmail API provides this)
result_size_estimate = (
data.get("resultSizeEstimate")
or data.get("result_size_estimate")
or data.get("data", {}).get("resultSizeEstimate")
or data.get("data", {}).get("result_size_estimate")
or (
inner_data.get("resultSizeEstimate")
if isinstance(inner_data, dict)
else None
)
or (
inner_data.get("result_size_estimate")
if isinstance(inner_data, dict)
else None
)
or response_data.get("resultSizeEstimate")
or response_data.get("result_size_estimate")
)
elif isinstance(data, list):
messages = data
@ -864,7 +926,7 @@ class ComposioService:
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
params={"message_id": message_id}, # snake_case
entity_id=entity_id,
)
@ -872,7 +934,13 @@ class ComposioService:
if not result.get("success"):
return None, result.get("error", "Unknown error")
return result.get("data"), None
data = result.get("data")
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
return inner_data.get("response_data", inner_data), None
return data, None
except Exception as e:
logger.error(f"Failed to get Gmail message detail: {e!s}")
@ -928,10 +996,27 @@ class ComposioService:
# Try different possible response structures
events = []
if isinstance(data, dict):
inner_data = data.get("data", data)
response_data = (
inner_data.get("response_data", {})
if isinstance(inner_data, dict)
else {}
)
events = (
data.get("items", [])
or data.get("data", {}).get("items", [])
or (
inner_data.get("items", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("items", [])
or data.get("events", [])
or (
inner_data.get("events", [])
if isinstance(inner_data, dict)
else []
)
or response_data.get("events", [])
)
elif isinstance(data, list):
events = data

View file

@ -1,6 +1,8 @@
import asyncio
import os
import time
from datetime import datetime
from threading import Lock
from typing import Any
import httpx
@ -2769,12 +2771,22 @@ class ConnectorService:
"""
Get all available (enabled) connector types for a search space.
Phase 1.4: results are cached per ``search_space_id`` for
:data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
identity the cached value is plain data, safe to share across
requests. Invalidate on connector add/update/delete via
:func:`invalidate_connector_discovery_cache`.
Args:
search_space_id: The search space ID
Returns:
List of SearchSourceConnectorType enums for enabled connectors
"""
cached = _get_cached_connectors(search_space_id)
if cached is not None:
return list(cached)
query = (
select(SearchSourceConnector.connector_type)
.filter(
@ -2784,8 +2796,9 @@ class ConnectorService:
)
result = await self.session.execute(query)
connector_types = result.scalars().all()
return list(connector_types)
connector_types = list(result.scalars().all())
_set_cached_connectors(search_space_id, connector_types)
return connector_types
async def get_available_document_types(
self,
@ -2794,12 +2807,22 @@ class ConnectorService:
"""
Get all document types that have at least one document in the search space.
Phase 1.4: cached per ``search_space_id`` for
:data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
:func:`invalidate_connector_discovery_cache` when a connector
finishes indexing new documents (or document types are otherwise
added/removed).
Args:
search_space_id: The search space ID
Returns:
List of document type strings that have documents indexed
"""
cached = _get_cached_doc_types(search_space_id)
if cached is not None:
return list(cached)
from sqlalchemy import distinct
from app.db import Document
@ -2809,5 +2832,164 @@ class ConnectorService:
)
result = await self.session.execute(query)
doc_types = result.scalars().all()
return [str(dt) for dt in doc_types]
doc_types = [str(dt) for dt in result.scalars().all()]
_set_cached_doc_types(search_space_id, doc_types)
return doc_types
# ---------------------------------------------------------------------------
# Connector / document-type discovery TTL cache (Phase 1.4)
# ---------------------------------------------------------------------------
#
# Both ``get_available_connectors`` and ``get_available_document_types`` are
# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
# hits Postgres and contributes to per-turn agent build latency. Their
# results change infrequently — only when the user adds/edits/removes a
# connector, or when an indexer commits a new document type. A short TTL
# cache (default 30s, env-tunable) collapses N concurrent calls into one
# DB roundtrip with bounded staleness.
#
# Invalidation: connector mutation routes (create / update / delete) call
# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
# entry for the affected space. Multi-replica deployments still pay one
# DB roundtrip per replica per TTL window, which is fine — staleness is
# bounded and the alternative (cross-replica fanout) is not worth the
# coupling here.
_DISCOVERY_TTL_SECONDS: float = float(
os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
)
# Per-search-space caches. Keyed by ``search_space_id``; value is
# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
# read-mostly workload, sub-microsecond contention.
_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
_cache_lock = Lock()
def _get_cached_connectors(
search_space_id: int,
) -> list[SearchSourceConnectorType] | None:
if _DISCOVERY_TTL_SECONDS <= 0:
return None
with _cache_lock:
entry = _connectors_cache.get(search_space_id)
if entry is None:
return None
expires_at, payload = entry
if time.monotonic() >= expires_at:
_connectors_cache.pop(search_space_id, None)
return None
return payload
def _set_cached_connectors(
search_space_id: int, payload: list[SearchSourceConnectorType]
) -> None:
if _DISCOVERY_TTL_SECONDS <= 0:
return
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
with _cache_lock:
_connectors_cache[search_space_id] = (expires_at, list(payload))
def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
if _DISCOVERY_TTL_SECONDS <= 0:
return None
with _cache_lock:
entry = _doc_types_cache.get(search_space_id)
if entry is None:
return None
expires_at, payload = entry
if time.monotonic() >= expires_at:
_doc_types_cache.pop(search_space_id, None)
return None
return payload
def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
if _DISCOVERY_TTL_SECONDS <= 0:
return
expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
with _cache_lock:
_doc_types_cache[search_space_id] = (expires_at, list(payload))
def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
"""Drop cached discovery results for ``search_space_id`` (or all spaces).
Connector CRUD routes / indexer pipelines call this when they mutate
the rows backing :func:`ConnectorService.get_available_connectors` /
:func:`get_available_document_types`. ``None`` clears every space
useful in tests and on bulk imports.
"""
with _cache_lock:
if search_space_id is None:
_connectors_cache.clear()
_doc_types_cache.clear()
else:
_connectors_cache.pop(search_space_id, None)
_doc_types_cache.pop(search_space_id, None)
def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
with _cache_lock:
if search_space_id is None:
_connectors_cache.clear()
else:
_connectors_cache.pop(search_space_id, None)
def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
with _cache_lock:
if search_space_id is None:
_doc_types_cache.clear()
else:
_doc_types_cache.pop(search_space_id, None)
def _register_invalidation_listeners() -> None:
"""Wire SQLAlchemy ORM events so cache stays consistent automatically.
Listening on ``after_insert`` / ``after_update`` / ``after_delete``
means every successful INSERT/UPDATE/DELETE that goes through the ORM
invalidates the affected search space's cached discovery payload —
no need to sprinkle ``invalidate_*`` calls across 30+ connector
routes. Bulk operations that bypass the ORM (e.g.
``session.execute(insert(...))`` without a mapped object) still need
explicit invalidation; document indexers already commit through the
ORM so document-type discovery is covered.
"""
from sqlalchemy import event
# Imported here (not at module top) to avoid a circular import:
# app.services.connector_service is itself imported from app.db's
# ecosystem indirectly via several CRUD modules.
from app.db import Document, SearchSourceConnector
def _connector_changed(_mapper, _connection, target) -> None:
sid = getattr(target, "search_space_id", None)
if sid is not None:
_invalidate_connectors_only(int(sid))
def _document_changed(_mapper, _connection, target) -> None:
sid = getattr(target, "search_space_id", None)
if sid is not None:
_invalidate_doc_types_only(int(sid))
for evt in ("after_insert", "after_update", "after_delete"):
event.listen(SearchSourceConnector, evt, _connector_changed)
event.listen(Document, evt, _document_changed)
try:
_register_invalidation_listeners()
except Exception: # pragma: no cover - defensive; never block module import
import logging as _logging
_logging.getLogger(__name__).exception(
"Failed to register connector discovery cache invalidation listeners; "
"stale cache risk: explicit invalidate_connector_discovery_cache calls "
"may be required."
)

View file

@ -17,7 +17,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@ -78,14 +78,49 @@ class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if (
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
)
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
return cca_id
def _unwrap_composio_data(self, data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def _execute_composio_gmail_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict[str, Any],
) -> tuple[Any, str | None]:
result = await ComposioService().execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Gmail error")
return self._unwrap_composio_data(result.get("data")), None
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if self._is_composio_connector(connector):
raise ValueError(
"Composio Gmail connectors must use Composio tool execution"
)
config_data = dict(connector.config)
@ -139,6 +174,12 @@ class GmailToolMetadataService:
if not connector:
return True
if self._is_composio_connector(connector):
_profile, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
)
return bool(error)
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
@ -221,14 +262,21 @@ class GmailToolMetadataService:
)
connector = result.scalar_one_or_none()
if connector:
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
profile = await asyncio.get_event_loop().run_in_executor(
None,
lambda service=service: (
service.users().getProfile(userId="me").execute()
),
)
if self._is_composio_connector(connector):
profile, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
)
if error:
raise RuntimeError(error)
else:
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
profile = await asyncio.get_event_loop().run_in_executor(
None,
lambda service=service: (
service.users().getProfile(userId="me").execute()
),
)
acc_dict["email"] = profile.get("emailAddress", "")
except Exception:
logger.warning(
@ -298,6 +346,23 @@ class GmailToolMetadataService:
Returns ``None`` on any failure so callers can degrade gracefully.
"""
try:
if self._is_composio_connector(connector):
if not draft_id:
draft_id = await self._find_composio_draft_id(connector, message_id)
if not draft_id:
return None
draft, error = await self._execute_composio_gmail_tool(
connector,
"GMAIL_GET_DRAFT",
{"user_id": "me", "draft_id": draft_id, "format": "full"},
)
if error or not isinstance(draft, dict):
return None
payload = draft.get("message", {}).get("payload", {})
return self._extract_body_from_payload(payload)
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
@ -326,6 +391,33 @@ class GmailToolMetadataService:
)
return None
async def _find_composio_draft_id(
self, connector: SearchSourceConnector, message_id: str
) -> str | None:
page_token = ""
while True:
params: dict[str, Any] = {
"user_id": "me",
"max_results": 100,
"verbose": False,
}
if page_token:
params["page_token"] = page_token
data, error = await self._execute_composio_gmail_tool(
connector, "GMAIL_LIST_DRAFTS", params
)
if error or not isinstance(data, dict):
return None
for draft in data.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft.get("id")
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
if not page_token:
return None
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
"""Resolve a draft ID from its message ID by scanning drafts.list."""
try:

View file

@ -14,6 +14,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.composio_service import ComposioService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
@ -21,7 +22,6 @@ from app.utils.document_converters import (
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
logger.warning("Document %s not found in KB", document_id)
return {"status": "not_indexed"}
creds = await self._build_credentials_for_connector(connector_id)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
calendar_id = (document.document_metadata or {}).get(
"calendar_id"
) or "primary"
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event_id)
.execute()
),
)
connector = await self._get_connector(connector_id)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_EVENTS_GET",
params={"calendar_id": calendar_id, "event_id": event_id},
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get("error", "Unknown Composio Calendar error")
)
live_event = composio_result.get("data", {})
if isinstance(live_event, dict):
live_event = live_event.get("data", live_event)
if isinstance(live_event, dict):
live_event = live_event.get("response_data", live_event)
else:
creds = await self._build_credentials_for_connector(connector_id)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event_id)
.execute()
),
)
event_summary = live_event.get("summary", "")
description = live_event.get("description", "")
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
result = await self.db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
connector = result.scalar_one_or_none()
if not connector:
raise ValueError(f"Connector {connector_id} not found")
return connector
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
connector = await self._get_connector(connector_id)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
raise ValueError("Composio connected_account_id not found")
raise ValueError(
"Composio Calendar connectors must use Composio tool execution"
)
config_data = dict(connector.config)

View file

@ -16,7 +16,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if (
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
)
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
return cca_id
async def _execute_composio_calendar_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict,
) -> tuple[dict | list | None, str | None]:
service = ComposioService()
result = await service.execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Calendar error")
data = result.get("data")
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner), None
return inner, None
return data, None
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if self._is_composio_connector(connector):
raise ValueError(
"Composio Calendar connectors must use Composio tool execution"
)
config_data = dict(connector.config)
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
if not connector:
return True
if self._is_composio_connector(connector):
_data, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_GET_CALENDAR",
{"calendar_id": "primary"},
)
return bool(error)
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
timezone_str = ""
if connector:
try:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
if self._is_composio_connector(connector):
cal_list, cal_error = await self._execute_composio_calendar_tool(
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
)
if cal_error:
raise RuntimeError(cal_error)
(
settings,
settings_error,
) = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_SETTINGS_GET",
{"setting": "timezone"},
)
if not settings_error and isinstance(settings, dict):
timezone_str = settings.get("value", "")
else:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
cal_list = await loop.run_in_executor(
None, lambda: service.calendarList().list().execute()
)
for cal in cal_list.get("items", []):
cal_list = await loop.run_in_executor(
None, lambda: service.calendarList().list().execute()
)
tz_setting = await loop.run_in_executor(
None,
lambda: service.settings().get(setting="timezone").execute(),
)
timezone_str = tz_setting.get("value", "")
calendar_items = []
if isinstance(cal_list, dict):
calendar_items = (
cal_list.get("items") or cal_list.get("calendars") or []
)
elif isinstance(cal_list, list):
calendar_items = cal_list
for cal in calendar_items:
calendars.append(
{
"id": cal.get("id", ""),
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
"primary": cal.get("primary", False),
}
)
tz_setting = await loop.run_in_executor(
None,
lambda: service.settings().get(setting="timezone").execute(),
)
timezone_str = tz_setting.get("value", "")
except Exception:
logger.warning(
"Failed to fetch calendars/timezone for connector %s",
@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
event_dict = event.to_dict()
try:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
calendar_id = event.calendar_id or "primary"
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event.event_id)
.execute()
),
)
if self._is_composio_connector(connector):
live_event, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_EVENTS_GET",
{"calendar_id": calendar_id, "event_id": event.event_id},
)
if error:
raise RuntimeError(error)
else:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event.event_id)
.execute()
),
)
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
event_dict["description"] = live_event.get(
@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved:
live_resolved = await self._resolve_live_event(
search_space_id, user_id, event_ref
)
if not live_resolved:
return {
"error": (
f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
"This could mean: (1) the event doesn't exist, "
"(2) the event name is different, or "
"(3) the connected calendar account cannot access it."
)
}
connector, live_event = live_resolved
account = GoogleCalendarAccount.from_connector(connector)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return {
"error": (
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the event name is different."
)
"account": acc_dict,
"event": self._event_dict_from_live_event(live_event),
}
document, connector = resolved
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
if row:
return row[0], row[1]
return None
async def _resolve_live_event(
self, search_space_id: int, user_id: str, event_ref: str
) -> tuple[SearchSourceConnector, dict] | None:
result = await self._db_session.execute(
select(SearchSourceConnector)
.filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
)
)
.order_by(SearchSourceConnector.last_indexed_at.desc())
)
connectors = result.scalars().all()
for connector in connectors:
try:
events = await self._search_live_events(connector, event_ref)
except Exception:
logger.warning(
"Failed to search live calendar events for connector %s",
connector.id,
exc_info=True,
)
continue
if not events:
continue
normalized_ref = event_ref.strip().lower()
exact_match = next(
(
event
for event in events
if event.get("summary", "").strip().lower() == normalized_ref
),
None,
)
return connector, exact_match or events[0]
return None
async def _search_live_events(
self, connector: SearchSourceConnector, event_ref: str
) -> list[dict]:
if self._is_composio_connector(connector):
data, error = await self._execute_composio_calendar_tool(
connector,
"GOOGLECALENDAR_EVENTS_LIST",
{
"calendar_id": "primary",
"q": event_ref,
"max_results": 10,
"single_events": True,
"order_by": "startTime",
},
)
if error:
raise RuntimeError(error)
if isinstance(data, dict):
return data.get("items") or data.get("events") or []
return data if isinstance(data, list) else []
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
response = await loop.run_in_executor(
None,
lambda: (
service.events()
.list(
calendarId="primary",
q=event_ref,
maxResults=10,
singleEvents=True,
orderBy="startTime",
)
.execute()
),
)
return response.get("items", [])
def _event_dict_from_live_event(self, event: dict) -> dict:
start_data = event.get("start", {})
end_data = event.get("end", {})
return {
"event_id": event.get("id", ""),
"summary": event.get("summary", "No Title"),
"start": start_data.get("dateTime", start_data.get("date", "")),
"end": end_data.get("dateTime", end_data.get("date", "")),
"description": event.get("description", ""),
"location": event.get("location", ""),
"attendees": [
{
"email": attendee.get("email", ""),
"responseStatus": attendee.get("responseStatus", ""),
}
for attendee in event.get("attendees", [])
],
"calendar_id": event.get("calendarId", "primary"),
"document_id": None,
"indexed_at": None,
}

View file

@ -13,7 +13,7 @@ from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
from app.services.composio_service import ComposioService
logger = logging.getLogger(__name__)
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
return (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
def _get_composio_connected_account_id(
self, connector: SearchSourceConnector
) -> str:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected_account_id not found")
return cca_id
async def _execute_composio_drive_tool(
self,
connector: SearchSourceConnector,
tool_name: str,
params: dict,
) -> tuple[dict | list | None, str | None]:
result = await ComposioService().execute_tool(
connected_account_id=self._get_composio_connected_account_id(connector),
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{connector.user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Drive error")
data = result.get("data")
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner), None
return inner, None
return data, None
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
if not connector:
return True
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
if self._is_composio_connector(connector):
_data, error = await self._execute_composio_drive_tool(
connector,
"GOOGLEDRIVE_LIST_FILES",
{
"q": "trashed = false",
"page_size": 1,
"fields": "files(id)",
},
)
return bool(error)
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
credentials=pre_built_creds,
)
await client.list_files(
query="trashed = false", page_size=1, fields="files(id)"
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
parent_folders[connector_id] = []
continue
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
if self._is_composio_connector(connector):
data, error = await self._execute_composio_drive_tool(
connector,
"GOOGLEDRIVE_LIST_FILES",
{
"q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
"fields": "files(id,name)",
"page_size": 50,
},
)
if error:
logger.warning(
"Failed to list folders for connector %s: %s",
connector_id,
error,
)
parent_folders[connector_id] = []
continue
folders = []
if isinstance(data, dict):
folders = data.get("files", [])
elif isinstance(data, list):
folders = data
parent_folders[connector_id] = [
{"folder_id": f["id"], "name": f["name"]}
for f in folders
if f.get("id") and f.get("name")
]
continue
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
credentials=pre_built_creds,
)
folders, _, error = await client.list_files(

View file

@ -20,6 +20,8 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
@ -152,12 +154,12 @@ class ImageGenRouterService:
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
else:
provider = config.get("provider", "").upper()
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
@ -165,9 +167,16 @@ class ImageGenRouterService:
"api_key": config.get("api_key"),
}
# Add optional api_base
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
# Resolve ``api_base`` so deployments don't silently inherit
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
# the wrong provider (see ``provider_api_base`` docstring).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):

View file

@ -134,42 +134,14 @@ PROVIDER_MAP = {
}
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
# request to an Azure endpoint, which then 404s with ``Resource not found``.
# Only providers with a well-known, stable public base URL are listed here —
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
# so their existing config-driven behaviour is preserved.
PROVIDER_DEFAULT_API_BASE = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
resolve_api_base,
)
class LLMRouterService:
@ -466,14 +438,14 @@ class LLMRouterService:
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
# requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = config.get("api_base")
if not api_base:
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
if not api_base:
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base

View file

@ -16,6 +16,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@ -496,8 +497,14 @@ async def get_vision_llm(
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
@ -519,6 +526,8 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@ -526,6 +535,13 @@ async def get_vision_llm(
)
return None
try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
@ -541,29 +557,46 @@ async def get_vision_llm(
return None
if global_cfg.get("custom_provider"):
model_string = (
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
)
provider_prefix = global_cfg["custom_provider"]
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
prefix = VISION_PROVIDER_MAP.get(
provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
)
model_string = f"{prefix}/{global_cfg['model_name']}"
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
if global_cfg.get("api_base"):
litellm_kwargs["api_base"] = global_cfg["api_base"]
api_base = resolve_api_base(
provider=global_cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=global_cfg.get("api_base"),
)
if api_base:
litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
return SanitizedChatLiteLLM(**litellm_kwargs)
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
user_id=owner_user_id,
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
@ -578,20 +611,26 @@ async def get_vision_llm(
return None
if vision_cfg.custom_provider:
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
provider_prefix = vision_cfg.custom_provider
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
prefix = VISION_PROVIDER_MAP.get(
provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
model_string = f"{prefix}/{vision_cfg.model_name}"
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
if vision_cfg.api_base:
litellm_kwargs["api_base"] = vision_cfg.api_base
api_base = resolve_api_base(
provider=vision_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=vision_cfg.api_base,
)
if api_base:
litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)

View file

@ -93,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool:
return output_mods == ["text"]
def _is_image_output_model(model: dict) -> bool:
"""Return True if the model can produce image output.
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
``["image"]`` for pure image generators, ``["text", "image"]`` for
multi-modal generators that also emit captions). We accept any model
that can output images; the call site decides whether to use the
image-generation API or chat completion.
"""
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def _is_vision_input_model(model: dict) -> bool:
"""Return True if the model can ingest an image AND emit text.
OpenRouter's ``architecture.input_modalities`` lists what the model
accepts; ``output_modalities`` lists what it produces. A vision LLM
is a model that takes images in and produces text out i.e. it can
answer questions about a screenshot or extract content from an
image. Pure image-to-image models (e.g. style transfer) and
text-only models are excluded.
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
output_mods = arch.get("output_modalities", []) or []
return "image" in input_mods and "text" in output_mods
def _supports_image_input(model: dict) -> bool:
"""Return True if the model accepts ``image`` in its input modalities.
Differs from :func:`_is_vision_input_model` in that it does NOT
require text output chat-tab models always emit text already (the
chat catalog filters by ``_is_text_output_model``), so the only
extra capability we need to track per chat config is whether the
model can ingest user-attached images. The chat selector and the
streaming task both key off this flag to prevent hitting an
OpenRouter 404 ``"No endpoints found that support image input"``
when the user uploads an image and selects a text-only model
(DeepSeek V3, Llama 3.x base, etc.).
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
return "image" in input_mods
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@ -175,6 +222,32 @@ async def _fetch_models_async() -> list[dict] | None:
return None
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
Pricing values are kept as the raw OpenRouter strings (e.g.
``"0.000003"``); ``pricing_registration`` converts them to floats
when registering with LiteLLM. Models with missing or malformed
pricing are simply omitted operator-side risk if any of those are
premium.
"""
pricing: dict[str, dict[str, str]] = {}
for model in raw_models:
model_id = str(model.get("id") or "").strip()
if not model_id:
continue
p = model.get("pricing") or {}
prompt = p.get("prompt")
completion = p.get("completion")
if prompt is None and completion is None:
continue
pricing[model_id] = {
"prompt": str(prompt) if prompt is not None else "",
"completion": str(completion) if completion is not None else "",
}
return pricing
def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
@ -266,6 +339,13 @@ def _generate_configs(
# account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium",
# Capability flag derived from ``architecture.input_modalities``.
# Read by the new-chat selector to dim image-incompatible models
# when the user has pending image attachments, and by
# ``stream_new_chat`` as a fail-fast safety net before the
# OpenRouter request would otherwise 404 with
# ``"No endpoints found that support image input"``.
"supports_image_input": _supports_image_input(model),
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
@ -282,6 +362,171 @@ def _generate_configs(
return configs
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
def _generate_image_gen_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter image-generation models into global image-gen
config dicts (matches the YAML shape consumed by ``image_generation_routes``).
Filter:
- architecture.output_modalities contains "image"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
``_has_sufficient_context``) because tool calls and context windows are
irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
derived per model the same way as chat (``_openrouter_tier``).
Cost is intentionally *not* registered with LiteLLM at startup
(``pricing_registration`` skips image gen): OpenRouter image-gen
models are not in LiteLLM's native cost map and OpenRouter populates
``response_cost`` directly from the response header. A defensive
branch in ``_extract_cost_usd`` handles the rare case where
``usage.cost`` is missing see ``token_tracking_service``.
"""
id_offset: int = int(
settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
free_rpm: int = settings.get("free_rpm", 20)
litellm_params: dict = settings.get("litellm_params") or {}
image_models = [
m
for m in raw_models
if _is_image_output_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in image_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (image generation)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
# ``image_generation/transformation`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
def _generate_vision_llm_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
Filter:
- architecture.input_modalities contains "image"
- architecture.output_modalities contains "text"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Vision-LLM is invoked from the indexer (image extraction during
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
filters do not apply: a small-context vision model that doesn't
advertise tool-calling is still perfectly viable for "describe this
image" prompts.
"""
id_offset: int = int(
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
litellm_params: dict = settings.get("litellm_params") or {}
vision_models = [
m
for m in raw_models
if _is_vision_input_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in vision_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
pricing = model.get("pricing") or {}
# Capture per-token prices so ``pricing_registration`` can
# register them with LiteLLM at startup (and so the cost
# estimator in ``estimate_call_reserve_micros`` can resolve
# them at reserve time).
try:
input_cost = float(pricing.get("prompt", 0) or 0)
except (TypeError, ValueError):
input_cost = 0.0
try:
output_cost = float(pricing.get("completion", 0) or 0)
except (TypeError, ValueError):
output_cost = 0.0
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (vision)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
"quota_reserve_tokens": quota_reserve_tokens,
"input_cost_per_token": input_cost or None,
"output_cost_per_token": output_cost or None,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@ -300,6 +545,19 @@ class OpenRouterIntegrationService:
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
# Raw OpenRouter pricing per model_id, captured at the same time
# we generate configs. Consumed by ``pricing_registration`` to
# teach LiteLLM the per-token cost of every dynamic deployment so
# the success-callback can populate ``response_cost`` correctly.
self._raw_pricing: dict[str, dict[str, str]] = {}
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -329,8 +587,32 @@ class OpenRouterIntegrationService:
self._initialized = True
return []
self._raw_models = raw_models
self._configs = _generate_configs(raw_models, settings)
self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
self._image_configs = _generate_image_gen_configs(raw_models, settings)
logger.info(
"OpenRouter integration: image-gen emission ON (%d models)",
len(self._image_configs),
)
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@ -369,6 +651,8 @@ class OpenRouterIntegrationService:
new_configs = _generate_configs(raw_models, self._settings)
new_by_id = {c["id"]: c for c in new_configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
self._raw_models = raw_models
from app.config import config as app_config
@ -382,6 +666,29 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
c
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@ -407,6 +714,21 @@ class OpenRouterIntegrationService:
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Re-register LiteLLM pricing for the freshly fetched catalogue
# so newly added OR models bill correctly on their first call.
# Runs before the router rebuild because the router may issue
# cost-table lookups during deployment registration.
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as exc:
logger.warning(
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
)
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
@ -635,3 +957,34 @@ class OpenRouterIntegrationService:
def get_config_by_id(self, config_id: int) -> dict | None:
return self._configs_by_id.get(config_id)
def get_image_generation_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter image-generation configs (empty
list when the ``image_generation_enabled`` flag is off).
Each entry already has ``billing_tier`` derived per-model from
OpenRouter's signals and is shaped to drop directly into
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter vision-LLM configs (empty list
when the ``vision_enabled`` flag is off).
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
so ``pricing_registration`` can teach LiteLLM the cost of these
models the same way it does for chat which keeps the billable
wrapper able to debit accurate micro-USD on a vision call.
"""
return list(self._vision_configs)
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
"""Return the cached raw OpenRouter pricing map.
Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
values are the strings OpenRouter publishes (USD per token),
never converted to floats here so the caller can decide how to
handle malformed or unset entries.
"""
return dict(self._raw_pricing)

View file

@ -0,0 +1,274 @@
"""
Pricing registration with LiteLLM.
Many models reach our LiteLLM callback without LiteLLM knowing their
per-token cost namely:
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
pricing table).
* Static YAML deployments whose ``base_model`` name is operator-defined
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
not in LiteLLM's table either.
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
and the user gets billed nothing a fail-safe but wrong answer for a
cost-based credit system. This module runs once at startup, after the
OpenRouter integration has fetched its catalogue, and registers each
known model's pricing with ``litellm.register_model()`` under multiple
plausible alias keys (LiteLLM's cost lookup may use any of them
depending on whether the call went through the Router, ChatLiteLLM,
or a direct ``acompletion``).
Operators who run a custom Azure deployment whose ``base_model`` name
isn't in LiteLLM's table can declare per-token pricing inline in
``global_llm_config.yaml`` via ``input_cost_per_token`` and
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
that declaration the model's calls debit 0 — never overbilled.
"""
from __future__ import annotations
import logging
from typing import Any
import litellm
logger = logging.getLogger(__name__)
def _safe_float(value: Any) -> float:
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
if value is None:
return 0.0
try:
f = float(value)
except (TypeError, ValueError):
return 0.0
return f if f > 0 else 0.0
def _alias_set_for_openrouter(model_id: str) -> list[str]:
"""Return the alias keys to register an OpenRouter model under.
LiteLLM's cost-callback lookup key varies by call path:
- Router with ``model="openrouter/X"`` kwargs["model"] is
typically ``openrouter/X``.
- LiteLLM's own provider routing may strip the prefix and pass the
bare ``X`` to the cost-table lookup.
Registering under both keeps the lookup hermetic regardless of
which path the call took.
"""
aliases = [f"openrouter/{model_id}", model_id]
return list(dict.fromkeys(a for a in aliases if a))
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
"""Return the alias keys to register a static YAML deployment under.
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
bare ``model_name`` because callbacks sometimes see whichever was
configured first.
"""
provider_lower = (provider or "").lower()
aliases: list[str] = []
if base_model:
aliases.append(base_model)
if provider_lower and base_model:
aliases.append(f"{provider_lower}/{base_model}")
if model_name and model_name != base_model:
aliases.append(model_name)
if provider_lower and model_name and model_name != base_model:
aliases.append(f"{provider_lower}/{model_name}")
# Azure deployments often surface as "azure/<name>"; normalise the
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
if provider_lower == "azure_openai":
if base_model:
aliases.append(f"azure/{base_model}")
if model_name and model_name != base_model:
aliases.append(f"azure/{model_name}")
return list(dict.fromkeys(a for a in aliases if a))
def _register(
aliases: list[str],
*,
input_cost: float,
output_cost: float,
provider: str,
mode: str = "chat",
) -> int:
"""Register a single pricing entry under every alias in ``aliases``.
Returns the count of aliases successfully registered.
"""
payload: dict[str, dict[str, Any]] = {}
for alias in aliases:
payload[alias] = {
"input_cost_per_token": input_cost,
"output_cost_per_token": output_cost,
"litellm_provider": provider,
"mode": mode,
}
if not payload:
return 0
try:
litellm.register_model(payload)
except Exception as exc:
logger.warning(
"[PricingRegistration] register_model failed for aliases=%s: %s",
aliases,
exc,
)
return 0
return len(payload)
def _register_chat_shape_configs(
configs: list[dict],
*,
or_pricing: dict[str, dict[str, str]],
label: str,
) -> tuple[int, int, int, list[str]]:
"""Common loop that registers per-token pricing for a list of "chat-shape"
configs (chat or vision LLM both use ``input_cost_per_token`` /
``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
"""
registered_models = 0
registered_aliases = 0
skipped_no_pricing = 0
sample_keys: list[str] = []
for cfg in configs:
provider = str(cfg.get("provider") or "").upper()
model_name = str(cfg.get("model_name") or "").strip()
litellm_params = cfg.get("litellm_params") or {}
base_model = str(litellm_params.get("base_model") or model_name).strip()
if provider == "OPENROUTER":
entry = or_pricing.get(model_name)
if entry:
input_cost = _safe_float(entry.get("prompt"))
output_cost = _safe_float(entry.get("completion"))
else:
# Vision configs from ``_generate_vision_llm_configs``
# carry their pricing inline because the OpenRouter
# raw-pricing cache is keyed by chat-catalogue model_id;
# vision flows pick up the inline values here.
input_cost = _safe_float(cfg.get("input_cost_per_token"))
output_cost = _safe_float(cfg.get("output_cost_per_token"))
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_openrouter(model_name)
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider="openrouter",
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
continue
input_cost = _safe_float(
cfg.get("input_cost_per_token")
or litellm_params.get("input_cost_per_token")
)
output_cost = _safe_float(
cfg.get("output_cost_per_token")
or litellm_params.get("output_cost_per_token")
)
if input_cost == 0.0 and output_cost == 0.0:
skipped_no_pricing += 1
continue
aliases = _alias_set_for_yaml(provider, model_name, base_model)
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider=provider_slug,
)
if count > 0:
registered_models += 1
registered_aliases += count
if len(sample_keys) < 6:
sample_keys.extend(aliases[:2])
logger.info(
"[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
"%d configs had no pricing data; sample registered keys=%s",
label,
registered_models,
registered_aliases,
skipped_no_pricing,
sample_keys,
)
return registered_models, registered_aliases, skipped_no_pricing, sample_keys
def register_pricing_from_global_configs() -> None:
"""Register pricing for every known LLM deployment with LiteLLM.
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
so vision calls (during indexing) can resolve cost the same way chat
calls do namely:
1. ``OPENROUTER``: pulls the cached raw pricing from
``OpenRouterIntegrationService`` (populated during its own
startup fetch) and converts the per-token strings to floats. For
vision configs that carry pricing inline (``input_cost_per_token`` /
``output_cost_per_token`` set on the cfg itself) we fall back to
those values when the OR cache misses the model.
2. Anything else: looks for operator-declared
``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
config block (top-level or nested under ``litellm_params``).
**Image generation is intentionally NOT registered here.** The cost
shape for image-gen is per-image (``output_cost_per_image``), not
per-token, and LiteLLM's ``register_model`` doesn't accept those
keys via the chat-cost path. OpenRouter image-gen models populate
``response_cost`` directly from their response header instead, and
Azure-native image-gen models are already in LiteLLM's cost map.
Calls without a resolved pair of costs are skipped, not registered
with zeros operators who forget pricing get a "$0 debit" warning
in ``TokenTrackingCallback`` rather than silently overwriting any
pricing LiteLLM might know natively.
"""
from app.config import config as app_config
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
vision_configs: list[dict] = list(
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
)
if not chat_configs and not vision_configs:
logger.info("[PricingRegistration] no global configs to register")
return
or_pricing: dict[str, dict[str, str]] = {}
try:
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
if OpenRouterIntegrationService.is_initialized():
or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
except Exception as exc:
logger.debug(
"[PricingRegistration] OpenRouter pricing not available yet: %s", exc
)
if chat_configs:
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
if vision_configs:
_register_chat_shape_configs(
vision_configs, or_pricing=or_pricing, label="vision"
)

View file

@ -0,0 +1,106 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
LiteLLM falls back to the module-global ``litellm.api_base`` when an
individual call doesn't pass one, which silently inherits provider-agnostic
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
``/chat/completions`` to whatever inherited base it gets, regardless of
provider).
The chat router has had this defense for a while
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
into a tiny standalone helper so vision and image-gen can share the same
source of truth without an inter-service circular import.
"""
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
Only providers with a well-known, stable public base URL are listed
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
huggingface, databricks, cloudflare, replicate) are intentionally omitted
so their existing config-driven behaviour is preserved."""
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
"""Canonical provider key (uppercase) → base URL.
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
config's ``provider`` field tells us which API it actually is (DeepSeek,
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
has its own base URL)."""
def resolve_api_base(
*,
provider: str | None,
provider_prefix: str | None,
config_api_base: str | None,
) -> str | None:
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
Cascade (first non-empty wins):
1. The config's own ``api_base`` (whitespace-only treated as missing).
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
4. ``None`` caller should NOT set ``api_base`` and let the LiteLLM
provider integration apply its own default (e.g. AzureOpenAI's
deployment-derived URL, custom provider's per-deployment URL).
Args:
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
``"DEEPSEEK"``). Case-insensitive.
provider_prefix: The LiteLLM model-string prefix the same call
site builds for the model id (e.g. ``"openrouter"``,
``"groq"``). Case-insensitive.
config_api_base: ``api_base`` from the global YAML / DB row /
OpenRouter dynamic config. Empty / whitespace-only means
"missing" the resolver still applies the cascade.
Returns:
A URL string, or ``None`` if no default applies for this provider.
"""
if config_api_base and config_api_base.strip():
return config_api_base
if provider:
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
if key_default:
return key_default
if provider_prefix:
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
if prefix_default:
return prefix_default
return None
__all__ = [
"PROVIDER_DEFAULT_API_BASE",
"PROVIDER_KEY_DEFAULT_API_BASE",
"resolve_api_base",
]

View file

@ -0,0 +1,280 @@
"""Capability resolution shared by chat / image / vision call sites.
Why this exists
---------------
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
single, authoritative answer to one question: *can this chat config accept
``image_url`` content blocks?* Without it, the new-chat selector can't badge
incompatible models and the streaming task can't fail fast with a friendly
error before sending an image to a text-only provider.
Two functions, two intents:
- :func:`derive_supports_image_input` best-effort *True* for catalog and
UI surfacing. Default-allow: an unknown / unmapped model is treated as
capable so we never lock the user out of a freshly added or
third-party-hosted vision model.
- :func:`is_known_text_only_chat_model` strict opt-out for the streaming
task's safety net. Returns True only when LiteLLM's model map *explicitly*
sets ``supports_vision=False`` (or its bare-name variant does). Anything
else missing key, lookup exception, ``supports_vision=True`` returns
False so the request flows through to the provider.
Implementation rule: only public LiteLLM symbols
------------------------------------------------
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
across releases. The private ``_is_explicitly_disabled_factory`` and
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
can't silently break us.
Why the previous round's strict YAML opt-in flag failed
-------------------------------------------------------
``supports_image_input: false`` was the YAML loader's setdefault. Operators
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
YAML chat model including vision-capable GPT-5.x and GPT-4o resolved to
False and the streaming gate rejected every image turn. Sourcing capability
from LiteLLM's authoritative model map (which already says
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
import litellm
logger = logging.getLogger(__name__)
# Provider-name → LiteLLM model-prefix map.
#
# Owned here because ``app.services.provider_capabilities`` is the
# only edge that's safe to call from ``app.config``'s YAML loader at
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
# this constant under the historical ``PROVIDER_MAP`` name; placing the
# map there directly would re-introduce the
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
# app.config`` cycle that prompted the move.
_PROVIDER_PREFIX_MAP: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _candidate_model_strings(
*,
provider: str | None,
model_name: str | None,
base_model: str | None,
custom_provider: str | None,
) -> list[tuple[str, str | None]]:
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
LiteLLM's capability lookup is keyed by ``model`` + (optional)
``custom_llm_provider``. Different config sources give us different
levels of detail, so we try the most-specific keys first and fall back
to bare model names so unannotated entries (e.g. an Azure deployment
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
map. Order matters the first lookup that returns a definitive answer
wins for both helpers.
"""
candidates: list[tuple[str, str | None]] = []
seen: set[tuple[str, str | None]] = set()
def _add(model: str | None, llm_provider: str | None) -> None:
if not model:
return
key = (model, llm_provider)
if key in seen:
return
seen.add(key)
candidates.append(key)
provider_prefix: str | None = None
if provider:
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
if custom_provider:
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
provider_prefix = custom_provider
primary_model = base_model or model_name
bare_model = model_name
# Most-specific first: provider-prefixed identifier with explicit
# custom_llm_provider so LiteLLM won't have to guess the provider via
# ``get_llm_provider``.
if primary_model and provider_prefix:
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
if "/" in primary_model:
_add(primary_model, provider_prefix)
else:
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
# Bare base_model (or model_name) with provider hint — handles entries
# the upstream map keys without a provider prefix (most ``gpt-*`` and
# ``claude-*`` entries do this).
if primary_model:
_add(primary_model, provider_prefix)
# Fallback to model_name when base_model differs (e.g. an Azure
# deployment whose model_name is the deployment id but base_model is the
# canonical OpenAI sku).
if bare_model and bare_model != primary_model:
if provider_prefix and "/" not in bare_model:
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
_add(bare_model, provider_prefix)
_add(bare_model, None)
return candidates
def derive_supports_image_input(
*,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
openrouter_input_modalities: Iterable[str] | None = None,
) -> bool:
"""Best-effort capability flag for the new-chat selector and catalog.
Resolution order (first definitive answer wins):
1. ``openrouter_input_modalities`` (when provided as a non-empty
iterable). OpenRouter exposes ``architecture.input_modalities`` per
model and that's the authoritative source for OR dynamic configs.
2. ``litellm.supports_vision`` against each candidate identifier from
:func:`_candidate_model_strings`. Returns True as soon as any
candidate confirms vision support.
3. Default ``True`` the conservative-allow stance. An unknown /
newly-added / third-party-hosted model is *not* pre-judged. The
streaming safety net (:func:`is_known_text_only_chat_model`) is the
only place a False ever blocks; everywhere else, a False here would
just hide a usable model from the user.
Returns:
True if the model can plausibly accept image input, False only when
OpenRouter explicitly says it can't.
"""
if openrouter_input_modalities is not None:
modalities = list(openrouter_input_modalities)
if modalities:
return "image" in modalities
# Empty list explicitly published by OR — treat as "no image".
return False
for model_string, custom_llm_provider in _candidate_model_strings(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
):
try:
if litellm.supports_vision(
model=model_string, custom_llm_provider=custom_llm_provider
):
return True
except Exception as exc:
logger.debug(
"litellm.supports_vision raised for model=%s provider=%s: %s",
model_string,
custom_llm_provider,
exc,
)
continue
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
return True
def is_known_text_only_chat_model(
*,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
) -> bool:
"""Strict opt-out probe for the streaming-task safety net.
Returns True only when LiteLLM's model map *explicitly* sets
``supports_vision=False`` for at least one candidate identifier. Missing
key, lookup exception, or ``supports_vision=True`` all return False so
the streaming task lets the request through. This is the inverse-default
of :func:`derive_supports_image_input`.
Why two functions
-----------------
The selector wants "show me everything that's plausibly capable"
default-allow. The safety net wants "block only when I'm certain it
can't" — default-pass. Mixing the two intents in a single function
leads to the regression we're fixing here.
"""
for model_string, custom_llm_provider in _candidate_model_strings(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
):
try:
info = litellm.get_model_info(
model=model_string, custom_llm_provider=custom_llm_provider
)
except Exception as exc:
logger.debug(
"litellm.get_model_info raised for model=%s provider=%s: %s",
model_string,
custom_llm_provider,
exc,
)
continue
# ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
# may be missing, None, True, or False. We only fire on explicit
# False — None / missing / True all mean "don't block".
try:
value = info.get("supports_vision") # type: ignore[union-attr]
except AttributeError:
value = None
if value is False:
return True
return False
__all__ = [
"derive_supports_image_input",
"is_known_text_only_chat_model",
]

View file

@ -0,0 +1,105 @@
"""
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
indexing pipeline (file processors, connector indexers, etl pipeline) can
keep invoking the LLM exactly the way they do today ``await llm.ainvoke(...)``
without threading ``user_id`` through every parser. The wrapper looks like
a chat model from the outside; on the inside it routes each call through
``billable_call`` so the user's premium credit pool is reserved → finalized
or released, and a ``TokenUsage`` audit row is written.
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
need quota enforcement) so this class only ever wraps premium configs.
Why a wrapper instead of plumbing ``user_id`` through every caller:
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
Dropbox, local-folder, file-processor, ETL pipeline) each calling
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
invasive, error-prone, and easy for a future indexer to forget.
* Per the design (issue M), we always debit the *search-space owner*, not
the triggering user, so ``user_id`` is fully derivable from the search
space the caller is already operating on. The wrapper captures it once
at construction time.
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
call run this coroutine"; subclassing isn't safe across versions because
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
Composition via attribute proxying (``__getattr__``) is robust to
upstream changes every method other than ``ainvoke`` falls through to
the inner LLM unchanged.
"""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from app.services.billable_calls import QuotaInsufficientError, billable_call
logger = logging.getLogger(__name__)
class QuotaCheckedVisionLLM:
"""Composition wrapper around a langchain chat model that enforces
premium credit quota on every ``ainvoke``.
Anything other than ``ainvoke`` is forwarded to the inner model so
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
still work they simply bypass quota enforcement, which is fine
because the indexing pipeline only ever calls ``ainvoke`` today.
"""
def __init__(
self,
inner_llm: Any,
*,
user_id: UUID,
search_space_id: int,
billing_tier: str,
base_model: str,
quota_reserve_tokens: int | None,
usage_type: str = "vision_extraction",
) -> None:
self._inner = inner_llm
self._user_id = user_id
self._search_space_id = search_space_id
self._billing_tier = billing_tier
self._base_model = base_model
self._quota_reserve_tokens = quota_reserve_tokens
self._usage_type = usage_type
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
"""Proxied async invoke that runs the underlying call inside
``billable_call``.
Raises:
QuotaInsufficientError: when the user has exhausted their
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
catches this and falls back to the document parser.
"""
async with billable_call(
user_id=self._user_id,
search_space_id=self._search_space_id,
billing_tier=self._billing_tier,
base_model=self._base_model,
quota_reserve_tokens=self._quota_reserve_tokens,
usage_type=self._usage_type,
call_details={"model": self._base_model},
):
return await self._inner.ainvoke(input, *args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Forward everything else (``invoke``, ``astream``, ``bind``,
``with_structured_output``, ) to the inner model.
``__getattr__`` is only consulted when the attribute is *not*
already found on the proxy, which is exactly the contract we
want methods we override stay on the proxy, the rest fall
through.
"""
return getattr(self._inner, name)
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]

View file

@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (debited 0 at finalize anyway when their
# cost can't be resolved).
_QUOTA_MIN_RESERVE_MICROS = 100
def estimate_call_reserve_micros(
*,
base_model: str,
quota_reserve_tokens: int | None,
) -> int:
"""Return the number of micro-USD to reserve for one premium call.
Computes a worst-case upper bound from LiteLLM's per-token pricing
table:
reserve_usd reserve_tokens x (input_cost + output_cost)
so the math scales with model cost Claude Opus + 4K reserve_tokens
naturally reserves $0.36, while a cheap model reserves only a few
cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
so a misconfigured "$1000/M" model can't lock the whole balance on
one call.
If ``litellm.get_model_info`` raises (model unknown) we fall back to
the floor 100 micros / $0.0001 which is enough to gate a sane
request without over-reserving for a model whose pricing the
operator hasn't declared yet.
"""
reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
if reserve_tokens <= 0:
reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL
try:
from litellm import get_model_info
info = get_model_info(base_model) if base_model else {}
input_cost = float(info.get("input_cost_per_token") or 0.0)
output_cost = float(info.get("output_cost_per_token") or 0.0)
except Exception as exc:
logger.debug(
"[quota_reserve] cost lookup failed for base_model=%s: %s",
base_model,
exc,
)
input_cost = 0.0
output_cost = 0.0
if input_cost == 0.0 and output_cost == 0.0:
return _QUOTA_MIN_RESERVE_MICROS
reserve_usd = reserve_tokens * (input_cost + output_cost)
reserve_micros = round(reserve_usd * 1_000_000)
if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
reserve_micros = _QUOTA_MIN_RESERVE_MICROS
if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
return reserve_micros
class QuotaScope(StrEnum):
ANONYMOUS = "anonymous"
PREMIUM = "premium"
@ -444,8 +509,16 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
reserve_tokens: int,
reserve_micros: int,
) -> QuotaResult:
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
premium credit balance.
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
all in micro-USD on this code path; callers (chat stream,
token-status route, FE display) convert to dollars by dividing
by 1_000_000.
"""
from app.db import User
user = (
@ -465,11 +538,11 @@ class TokenQuotaService:
limit=0,
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
effective = used + reserved + reserve_tokens
effective = used + reserved + reserve_micros
if effective > limit:
remaining = max(0, limit - used - reserved)
await db_session.rollback()
@ -482,10 +555,10 @@ class TokenQuotaService:
remaining=remaining,
)
user.premium_tokens_reserved = reserved + reserve_tokens
user.premium_credit_micros_reserved = reserved + reserve_micros
await db_session.commit()
new_reserved = reserved + reserve_tokens
new_reserved = reserved + reserve_micros
remaining = max(0, limit - used - new_reserved)
warning_threshold = int(limit * 0.8)
@ -510,9 +583,12 @@ class TokenQuotaService:
db_session: AsyncSession,
user_id: Any,
request_id: str,
actual_tokens: int,
reserved_tokens: int,
actual_micros: int,
reserved_micros: int,
) -> QuotaResult:
"""Settle the reservation: release ``reserved_micros`` and debit
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
"""
from app.db import User
user = (
@ -529,16 +605,18 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
user.premium_credit_micros_used = (
user.premium_credit_micros_used + actual_micros
)
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
await db_session.commit()
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)
@ -562,8 +640,13 @@ class TokenQuotaService:
async def premium_release(
db_session: AsyncSession,
user_id: Any,
reserved_tokens: int,
reserved_micros: int,
) -> None:
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
Used when a request fails before finalize (so the reservation
doesn't leak credit).
"""
from app.db import User
user = (
@ -576,8 +659,8 @@ class TokenQuotaService:
.scalar_one_or_none()
)
if user is not None:
user.premium_tokens_reserved = max(
0, user.premium_tokens_reserved - reserved_tokens
user.premium_credit_micros_reserved = max(
0, user.premium_credit_micros_reserved - reserved_micros
)
await db_session.commit()
@ -598,9 +681,9 @@ class TokenQuotaService:
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
)
limit = user.premium_tokens_limit
used = user.premium_tokens_used
reserved = user.premium_tokens_reserved
limit = user.premium_credit_micros_limit
used = user.premium_credit_micros_used
reserved = user.premium_credit_micros_reserved
remaining = max(0, limit - used - reserved)
warning_threshold = int(limit * 0.8)

View file

@ -16,11 +16,14 @@ from __future__ import annotations
import dataclasses
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
@ -35,6 +38,8 @@ class TokenCallRecord:
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_micros: int = 0
call_kind: str = "chat"
@dataclass
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
cost_micros: int = 0,
call_kind: str = "chat",
) -> None:
self.calls.append(
TokenCallRecord(
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
"""Return token counts grouped by model name."""
"""Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
{
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_micros": 0,
},
)
entry["prompt_tokens"] += c.prompt_tokens
entry["completion_tokens"] += c.completion_tokens
entry["total_tokens"] += c.total_tokens
entry["cost_micros"] += c.cost_micros
return by_model
@property
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
def total_completion_tokens(self) -> int:
return sum(c.completion_tokens for c in self.calls)
@property
def total_cost_micros(self) -> int:
"""Sum of per-call ``cost_micros`` across the entire turn.
Used by ``stream_new_chat`` to debit a premium turn's actual
provider cost (in micro-USD) from the user's premium credit
balance. ``cost_micros`` per call is captured by
``TokenTrackingCallback.async_log_success_event`` from
``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
with multiple fallback paths so OpenRouter dynamic models and
custom Azure deployments still bill correctly when our
``pricing_registration`` ran at startup.
"""
return sum(c.cost_micros for c in self.calls)
def serialized_calls(self) -> list[dict[str, Any]]:
return [dataclasses.asdict(c) for c in self.calls]
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
def start_turn() -> TurnTokenAccumulator:
"""Create a fresh accumulator for the current async context and return it."""
"""Create a fresh accumulator for the current async context and return it.
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
short-lived per-call billable wrappers (image generation REST endpoint,
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
ContextVar reset token to restore the *previous* accumulator on exit and
avoids leaking call records across reservations (issue B).
"""
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``
for the duration of the ``async with`` block, then *resets* the
ContextVar to its previous value on exit.
This is the safe primitive for per-call billable operations
(image generation, vision LLM extraction, podcasts) that may run
inside an outer chat turn or be called sequentially from the same
background worker. Using ``ContextVar.set`` without ``reset`` (as
:func:`start_turn` does) would leak the inner accumulator into the
outer scope, causing the outer chat turn to debit cost twice.
Usage::
async with scoped_turn() as acc:
await llm.ainvoke(...)
# acc.total_cost_micros captures cost from the LiteLLM callback
# Outer accumulator (if any) is restored here.
"""
acc = TurnTokenAccumulator()
token = _turn_accumulator.set(acc)
logger.debug(
"[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
id(acc),
token,
)
try:
yield acc
finally:
_turn_accumulator.reset(token)
logger.debug(
"[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
id(acc),
len(acc.calls),
acc.total_cost_micros,
)
def _extract_cost_usd(
kwargs: dict[str, Any],
response_obj: Any,
model: str,
prompt_tokens: int,
completion_tokens: int,
is_image: bool = False,
) -> float:
"""Best-effort USD cost extraction for a single LLM/image call.
Tries four sources in priority order and returns the first that
yields a positive number; returns 0.0 if all four fail (the call
will then debit nothing from the user's balance — fail-safe).
Sources:
1. ``kwargs["response_cost"]`` LiteLLM's standard callback
field, populated for ``Router.acompletion`` since PR #12500.
2. ``response_obj._hidden_params["response_cost"]`` same value
exposed on the response itself.
3. ``litellm.completion_cost(completion_response=response_obj)``
recompute from the response and LiteLLM's pricing table.
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
manual fallback for OpenRouter/custom-Azure models that
only resolve via aliases registered by
``pricing_registration`` at startup. **Skipped for image
responses** ``cost_per_token`` does not support ``ImageResponse``
and would raise; the cost map for image-gen lives in different
keys (``output_cost_per_image``) handled by ``completion_cost``.
"""
cost = kwargs.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
hidden = getattr(response_obj, "_hidden_params", None) or {}
if isinstance(hidden, dict):
cost = hidden.get("response_cost")
if cost is not None:
try:
value = float(cost)
except (TypeError, ValueError):
value = 0.0
if value > 0:
return value
try:
value = float(litellm.completion_cost(completion_response=response_obj))
if value > 0:
return value
except Exception as exc:
if is_image:
# Image-gen path: OpenRouter's image responses can omit
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
# then *raises* (no cost map for OpenRouter image models).
# Bail out with a warning rather than falling through to
# cost_per_token (which is also incompatible with ImageResponse).
logger.warning(
"[TokenTracking] completion_cost failed for image model=%s "
"(provider may have omitted usage.cost). Debiting 0. "
"Cause: %s",
model,
exc,
)
return 0.0
logger.debug(
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
)
if is_image:
# Never call cost_per_token for ImageResponse — keys mismatch and
# the function is documented chat-only.
return 0.0
if model and (prompt_tokens > 0 or completion_tokens > 0):
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
value = float(prompt_cost) + float(completion_cost)
if value > 0:
return value
except Exception as exc:
logger.debug(
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
)
return 0.0
class TokenTrackingCallback(CustomLogger):
"""LiteLLM callback that captures token usage into the turn accumulator."""
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
)
return
# Detect image generation responses — they have a different usage
# shape (ImageUsage with input_tokens/output_tokens) and require a
# different cost-extraction path. We probe by class name to avoid a
# hard import dependency on litellm internals.
response_cls = type(response_obj).__name__
is_image = response_cls == "ImageResponse"
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug(
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
)
return
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
if is_image:
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
# (not prompt_tokens/completion_tokens). Several providers
# populate only one or neither (e.g. OpenRouter's gpt-image-1
# passes through `input_tokens` from the prompt but no
# completion); fall through gracefully to 0.
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
completion_tokens = getattr(usage, "output_tokens", 0) or 0
total_tokens = (
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
)
call_kind = "image_generation"
else:
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
total_tokens = getattr(usage, "total_tokens", 0) or 0
call_kind = "chat"
model = kwargs.get("model", "unknown")
cost_usd = _extract_cost_usd(
kwargs=kwargs,
response_obj=response_obj,
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
is_image=is_image,
)
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
logger.warning(
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
"for image generation).",
model,
prompt_tokens,
completion_tokens,
call_kind,
)
acc.add(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
call_kind=call_kind,
)
logger.info(
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
model,
call_kind,
prompt_tokens,
completion_tokens,
total_tokens,
cost_usd,
cost_micros,
len(acc.calls),
)
@ -168,6 +388,7 @@ async def record_token_usage(
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
cost_micros: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
@ -185,6 +406,7 @@ async def record_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost_micros=cost_micros,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
@ -194,11 +416,12 @@ async def record_token_usage(
)
session.add(record)
logger.debug(
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
usage_type,
prompt_tokens,
completion_tokens,
total_tokens,
cost_micros,
)
return record
except Exception:

View file

@ -3,6 +3,8 @@ from typing import Any
from litellm import Router
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
@ -108,10 +110,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider = config.get("provider", "").upper()
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
@ -120,8 +123,13 @@ class VisionLLMRouterService:
"api_key": config.get("api_key"),
}
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]

View file

@ -1,10 +1,25 @@
"""Celery tasks package."""
"""Celery tasks package.
Also hosts the small helpers every async celery task should use to
spin up its event loop. See :func:`run_async_celery_task` for the
canonical pattern.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
logger = logging.getLogger(__name__)
_celery_engine = None
_celery_session_maker = None
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False
)
return _celery_session_maker
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
"""Drop the shared ``app.db.engine`` connection pool synchronously.
The shared engine (used by ``shielded_async_session`` and most
routes / services) is a module-level singleton with a real pool.
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
connections cache a reference to whichever loop opened them. When
a subsequent task's loop pulls a stale connection from the pool,
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
AttributeError: 'NoneType' object has no attribute 'send'
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
self._write_fut = self._loop._proactor.send(self._sock, data)
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
coroutine that can never run because its loop is gone.
Disposing the engine forces the pool to drop every cached
connection so the next checkout opens a fresh one on the current
loop. Safe to call from a task's finally block; failure is logged
but never propagated.
"""
try:
from app.db import engine as shared_engine
loop.run_until_complete(shared_engine.dispose())
except Exception:
logger.warning("Shared DB engine dispose() failed", exc_info=True)
T = TypeVar("T")
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
"""Run an async coroutine inside a fresh event loop with proper
DB-engine cleanup.
This is the canonical entry point for every async celery task.
It performs three responsibilities that were previously copy-pasted
(incorrectly) across each task module:
1. Create a fresh ``asyncio`` loop and install it on the current
thread (celery's ``--pool=solo`` runs every task on the main
thread, but other pool types don't).
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
any stale connections left over from a previous task's loop
are dropped defends against tasks that crashed without
cleaning up.
3. Dispose the shared engine AFTER the task runs so the
connections we opened on this loop are released before the
loop closes (avoids ``coroutine 'Connection._cancel' was
never awaited`` warnings and the next-task hang).
Use as::
@celery_app.task(name="my_task", bind=True)
def my_task(self, *args):
return run_async_celery_task(lambda: _my_task_impl(*args))
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Defense-in-depth: prior task may have crashed before
# disposing. Idempotent — no-op if pool is already empty.
_dispose_shared_db_engine(loop)
return loop.run_until_complete(coro_factory())
finally:
# Drop any connections this task opened so they don't leak
# into the next task's loop.
_dispose_shared_db_engine(loop)
with contextlib.suppress(Exception):
loop.run_until_complete(loop.shutdown_asyncgens())
with contextlib.suppress(Exception):
asyncio.set_event_loop(None)
loop.close()
__all__ = [
"get_celery_session_maker",
"run_async_celery_task",
]

View file

@ -4,7 +4,7 @@ import logging
import traceback
from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str,
):
"""Celery task to index Notion pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_notion_pages(
return run_async_celery_task(
lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
finally:
loop.close()
async def _index_notion_pages(
@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str,
):
"""Celery task to index GitHub repositories."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_github_repos(
@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str,
):
"""Celery task to index Confluence pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_confluence_pages(
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str,
):
"""Celery task to index Google Calendar events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_calendar_events(
return run_async_celery_task(
lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
finally:
loop.close()
async def _index_google_calendar_events(
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str,
):
"""Celery task to index Google Gmail messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_google_gmail_messages(
@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
):
"""Celery task to index Google Drive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_google_drive_files(
@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict,
):
"""Celery task to index OneDrive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_onedrive_files(
@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict,
):
"""Celery task to index Dropbox folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_dropbox_files(
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str,
):
"""Celery task to index Elasticsearch documents."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_elasticsearch_documents(
@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str,
):
"""Celery task to index Web page Urls."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_crawled_urls(
return run_async_celery_task(
lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
finally:
loop.close()
async def _index_crawled_urls(
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str,
):
"""Celery task to index BookStack pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_bookstack_pages(
@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None,
):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_composio_connector(

View file

@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex
user_id: ID of user who edited the document
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reindex_document(document_id, user_id))
finally:
loop.close()
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
async def _reindex_document(document_id: int, user_id: str):

View file

@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder,
index_uploaded_files,
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
)
def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_document_background(document_id))
finally:
loop.close()
return run_async_celery_task(lambda: _delete_document_background(document_id))
async def _delete_document_background(document_id: int) -> None:
@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None,
):
"""Celery task to delete documents first, then the folder rows."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_delete_folder_documents(document_ids, folder_subtree_ids)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
)
async def _delete_folder_documents(
@ -209,12 +199,9 @@ async def _delete_folder_documents(
)
def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_search_space_background(search_space_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_search_space_background(search_space_id)
)
async def _delete_search_space_background(search_space_id: int) -> None:
@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space
user_id: ID of the user
"""
# Create a new event loop for this task
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_extension_document(
individual_document_dict, search_space_id, user_id
)
return run_async_celery_task(
lambda: _process_extension_document(
individual_document_dict, search_space_id, user_id
)
finally:
loop.close()
)
async def _process_extension_document(
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space
user_id: ID of the user
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _process_youtube_video(url, search_space_id, user_id)
)
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_upload(file_path, filename, search_space_id, user_id)
run_async_celery_task(
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
)
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _process_file_upload(
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start."
)
# Mark document as failed since file is missing
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
run_async_celery_task(
lambda: _mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
finally:
loop.close()
)
return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_with_document(
run_async_celery_task(
lambda: _process_file_with_document(
document_id,
temp_path,
filename,
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _mark_document_failed(document_id: int, reason: str):
@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
return run_async_celery_task(
lambda: _process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
finally:
loop.close()
)
async def _process_circleback_meeting(
@ -1291,25 +1246,19 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None,
):
"""Celery task to index a local folder. Config is passed directly — no connector row."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
return run_async_celery_task(
lambda: _index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
finally:
loop.close()
)
async def _index_local_folder_async(
@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
return run_async_celery_task(
lambda: _index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
finally:
loop.close()
)
async def _index_uploaded_folder_files_async(
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_search_space_async(search_space_id, user_id)
)
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_ai_sort_document_async(search_space_id, user_id, document_id)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
)
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):

Some files were not shown because too many files have changed in this diff Show more