Merge upstream/dev into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-05 01:44:46 +02:00
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions

View file

@ -144,6 +144,11 @@ jobs:
APPLE_ID: ${{ secrets.APPLE_ID }} APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
# TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
# Surfaces the exact codesign / notarize commands electron-builder spawns,
# so we can see which subprocess hangs.
DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
# Service principal credentials for Azure.Identity EnvironmentCredential used by the # Service principal credentials for Azure.Identity EnvironmentCredential used by the
# TrustedSigning PowerShell module. Only populated when signing is enabled. # TrustedSigning PowerShell module. Only populated when signing is enabled.
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing, # electron-builder 26 does not yet support OIDC federated tokens for Azure signing,

60
.github/workflows/notary-status.yml vendored Normal file
View file

@ -0,0 +1,60 @@
name: Notary status check
# One-off diagnostic workflow. Queries Apple's notary service to see if your
# submissions are queued, in progress, accepted, or rejected. Useful when a
# notarization seems "hung" — most often the queue itself, especially on a
# brand-new Apple Developer account.
#
# Run via: Actions tab -> "Notary status check" -> Run workflow.
# Inputs are optional; if you provide a submission ID, it also fetches that
# submission's full Apple log.
#
# Safe to delete after diagnosis.
on:
workflow_dispatch:
inputs:
submission_id:
description: 'Optional: submission UUID to fetch full Apple log for'
required: false
default: ''
jobs:
status:
runs-on: macos-latest
steps:
- name: List recent notarization submissions
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
run: |
set -euo pipefail
echo "::group::Submission history (most recent first)"
xcrun notarytool history \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID"
echo "::endgroup::"
- name: Inspect specific submission (if id provided)
if: ${{ inputs.submission_id != '' }}
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
SUBMISSION_ID: ${{ inputs.submission_id }}
run: |
set -euo pipefail
echo "::group::Submission info"
xcrun notarytool info "$SUBMISSION_ID" \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID"
echo "::endgroup::"
echo "::group::Apple's processing log for this submission"
xcrun notarytool log "$SUBMISSION_ID" \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID" || true
echo "::endgroup::"

31
.vscode/launch.json vendored
View file

@ -26,7 +26,16 @@
"pythonArgs": [ "pythonArgs": [
"run", "run",
"python" "python"
] ],
// Mute LangGraph/Pydantic checkpoint serializer warnings
// (UserWarnings emitted from pydantic/main.py when the
// runtime snapshots a SurfSenseContextSchema into a field
// typed `None`) so the debugger's "Raised Exceptions"
// breakpoint doesn't pause on a known-harmless event.
// Production logs are unaffected.
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
}, },
{ {
"name": "Backend: FastAPI (No Reload)", "name": "Backend: FastAPI (No Reload)",
@ -40,7 +49,10 @@
"pythonArgs": [ "pythonArgs": [
"run", "run",
"python" "python"
] ],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
}, },
{ {
"name": "Backend: FastAPI (main.py)", "name": "Backend: FastAPI (main.py)",
@ -54,7 +66,10 @@
"pythonArgs": [ "pythonArgs": [
"run", "run",
"python" "python"
] ],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
}, },
{ {
"name": "Frontend: Next.js", "name": "Frontend: Next.js",
@ -104,7 +119,10 @@
"pythonArgs": [ "pythonArgs": [
"run", "run",
"python" "python"
] ],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
}, },
{ {
"name": "Celery: Beat Scheduler", "name": "Celery: Beat Scheduler",
@ -124,7 +142,10 @@
"pythonArgs": [ "pythonArgs": [
"run", "run",
"python" "python"
] ],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
} }
], ],
"compounds": [ "compounds": [

View file

@ -1 +1 @@
0.0.19 0.0.21

View file

@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
# STRIPE_RECONCILIATION_BATCH_SIZE=100 # STRIPE_RECONCILIATION_BATCH_SIZE=100
# Premium token purchases ($1 per 1M tokens for premium-tier models) # Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
# credit; premium turns debit the actual per-call provider cost
# reported by LiteLLM, so cheap and expensive models bill proportionally)
# STRIPE_TOKEN_BUYING_ENABLED=FALSE # STRIPE_TOKEN_BUYING_ENABLED=FALSE
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
# STRIPE_TOKENS_PER_UNIT=1000000 # STRIPE_CREDIT_MICROS_PER_UNIT=1000000
# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text) # TTS & STT (Text-to-Speech / Speech-to-Text)
@ -305,6 +308,24 @@ STT_SERVICE=local/base
# Advanced (optional) # Advanced (optional)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# New-chat agent feature flags
SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_BUSY_MUTEX=true
SURFSENSE_ENABLE_SKILLS=true
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
SURFSENSE_ENABLE_ACTION_LOG=true
SURFSENSE_ENABLE_REVERT_ROUTE=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
# Periodic connector sync interval (default: 5m) # Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m # SCHEDULE_CHECKER_INTERVAL=5m
@ -315,9 +336,24 @@ STT_SERVICE=local/base
# Pages limit per user for ETL (default: unlimited) # Pages limit per user for ETL (default: unlimited)
# PAGES_LIMIT=500 # PAGES_LIMIT=500
# Premium token quota per registered user (default: 5M) # Premium credit quota per registered user, in micro-USD (default: $5).
# Only applies to models with billing_tier=premium in global_llm_config.yaml # Premium turns are debited at the actual per-call provider cost reported
# PREMIUM_TOKEN_LIMIT=5000000 # by LiteLLM. Only applies to models with billing_tier=premium.
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
# QUOTA_MAX_RESERVE_MICROS=1000000
# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
# Per-podcast reservation for the podcast Celery task ($0.20 default).
# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
# Per-video-presentation reservation for the video Celery task ($1.00 default).
# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — public users can chat without an account # No-login (anonymous) mode — public users can chat without an account
# Set TRUE to enable /free pages and anonymous chat API # Set TRUE to enable /free pages and anonymous chat API

View file

@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
# Set FALSE to disable new checkout session creation temporarily # Set FALSE to disable new checkout session creation temporarily
STRIPE_PAGE_BUYING_ENABLED=TRUE STRIPE_PAGE_BUYING_ENABLED=TRUE
# Premium token purchases via Stripe (for premium-tier model usage) # Premium credit purchases via Stripe (for premium-tier model usage).
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens) # Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
# per-call provider cost reported by LiteLLM.
STRIPE_TOKEN_BUYING_ENABLED=FALSE STRIPE_TOKEN_BUYING_ENABLED=FALSE
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
STRIPE_TOKENS_PER_UNIT=1000000 STRIPE_CREDIT_MICROS_PER_UNIT=1000000
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
# STRIPE_TOKENS_PER_UNIT=1000000
# Periodic Stripe safety net for purchases left in PENDING (minutes old) # Periodic Stripe safety net for purchases left in PENDING (minutes old)
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
PAGES_LIMIT=500 PAGES_LIMIT=500
# Premium token quota per registered user (default: 3,000,000) # Premium credit quota per registered user, in micro-USD
# Applies only to models with billing_tier=premium in global_llm_config.yaml # (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
PREMIUM_TOKEN_LIMIT=3000000 # actual per-call provider cost reported by LiteLLM, so cheap and expensive
# models bill proportionally. Applies only to models with
# billing_tier=premium in global_llm_config.yaml.
PREMIUM_CREDIT_MICROS_LIMIT=5000000
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
# PREMIUM_TOKEN_LIMIT=5000000
# Safety ceiling on per-call premium reservation, in micro-USD.
# stream_new_chat estimates an upper-bound cost from the model's
# litellm-published per-token rates × the config's quota_reserve_tokens
# and clamps to this value so a misconfigured model can't lock the
# user's whole balance on one call. Default $1.00.
QUOTA_MAX_RESERVE_MICROS=1000000
# Per-image reservation (in micro-USD) for the POST /image-generations
# endpoint. Bypassed for free configs. Default $0.05.
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
# Single envelope covers one transcript-generation LLM call. Default $0.20.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
# Per-video-presentation reservation (in micro-USD) used by the video
# presentation Celery task. Covers worst-case fan-out of N slide-scene
# generations + refines. Default $1.00. NOTE: tasks using the override
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
# No-login (anonymous) mode — allows public users to chat without an account # No-login (anonymous) mode — allows public users to chat without an account
# Set TRUE to enable /free pages and anonymous chat API # Set TRUE to enable /free pages and anonymous chat API
@ -297,3 +327,30 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_PLUGIN_LOADER=false # SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names # Comma-separated allowlist of plugin entry-point names
# SURFSENSE_ALLOWED_PLUGINS=year_substituter # SURFSENSE_ALLOWED_PLUGINS=year_substituter
# -----------------------------------------------------------------------------
# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
# -----------------------------------------------------------------------------
# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
# on a cold turn) is reused across subsequent turns on the same thread,
# collapsing it to a microsecond hash lookup. All connector tools acquire
# their own short-lived DB session per call (Phase 2 refactor) so a cached
# closure is safe to share across requests. Flip OFF only as a last-resort
# rollback if you suspect cache-related staleness.
# SURFSENSE_ENABLE_AGENT_CACHE=true
# Cache capacity (max number of compiled-agent entries kept in memory)
# and TTL per entry (seconds). Working set is typically one entry per
# active thread on this replica; tune up for very large deployments.
# SURFSENSE_AGENT_CACHE_MAXSIZE=256
# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
# -----------------------------------------------------------------------------
# Connector discovery TTL cache (Phase 1.4 perf optimization)
# -----------------------------------------------------------------------------
# Caches the per-search-space "available connectors" + "available document
# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
# ORM event listeners auto-invalidate on connector / document inserts,
# updates and deletes — the TTL only bounds staleness for bulk-import
# paths that bypass the ORM. Set to 0 to disable the cache.
# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30

View file

@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
COPY pyproject.toml . COPY pyproject.toml .
COPY uv.lock . COPY uv.lock .
# Install PyTorch based on architecture # Install all Python dependencies from uv.lock for deterministic builds.
RUN if [ "$(uname -m)" = "x86_64" ]; then \ #
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \ # `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
else \ # which lets prod silently drift to newer upstream versions on every rebuild
pip install --no-cache-dir torch torchvision torchaudio; \ # (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
fi # Exporting the lock to requirements.txt and feeding it to `uv pip install`
# pins every transitive package to the exact version captured in uv.lock.
# Install python dependencies #
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
# downloads that the lock-based install immediately replaced. If a specific
# CUDA version is needed (driver compatibility, etc.), wire it through
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
RUN pip install --no-cache-dir uv && \ RUN pip install --no-cache-dir uv && \
uv pip install --system --no-cache-dir -e . uv export --frozen --no-dev --no-hashes --no-emit-project \
--format requirements-txt -o /tmp/requirements.txt && \
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
rm /tmp/requirements.txt
# Set SSL environment variables dynamically # Set SSL environment variables dynamically
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \ RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
# Pre-download Docling models # Pre-download Docling models
RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
# Install Playwright browsers for web scraping if needed # Install Playwright browsers for web scraping (the playwright package itself
RUN pip install playwright && \ # is already installed via uv.lock above)
playwright install chromium --with-deps RUN playwright install chromium --with-deps
# Copy source code # Copy source code
COPY . . COPY . .
# Install the project itself in editable mode. Dependencies were already
# installed deterministically from uv.lock above, so --no-deps prevents any
# re-resolution that could pull newer versions.
RUN uv pip install --system --no-cache-dir --no-deps -e .
# Copy and set permissions for entrypoint script # Copy and set permissions for entrypoint script
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts) # Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh

View file

@ -0,0 +1,44 @@
"""138_add_thread_auto_model_pinning_fields
Revision ID: 138
Revises: 137
Create Date: 2026-04-30
Add a single thread-level column to persist the Auto (Fastest) model pin:
- pinned_llm_config_id: concrete resolved global LLM config id used for this
thread. NULL means "no pin; Auto will resolve on next turn".
The column is unindexed: all reads are by new_chat_threads.id (primary key),
so a secondary index would be dead write amplification.
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
revision: str = "138"
down_revision: str | None = "137"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.execute(
"ALTER TABLE new_chat_threads "
"ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
)
def downgrade() -> None:
# Drop any shape the thread row may be carrying. The extra columns and
# indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
# makes each statement a safe no-op on the lean shape.
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
op.execute(
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id"
)

View file

@ -0,0 +1,160 @@
"""add user table to zero_publication with column list
Adds the "user" table to zero_publication with a column-list publication
so that only the 5 fields driving the live usage meters are replicated
through WAL -> zero-cache -> browser IndexedDB:
id, pages_limit, pages_used,
premium_tokens_limit, premium_tokens_used
Sensitive columns (hashed_password, email, oauth_account, display_name,
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
included in the publication, so they never enter WAL replication.
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
(it is already DEFAULT today since "user" was never in the
TABLES_WITH_FULL_IDENTITY list of migration 117).
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Revision ID: 139
Revises: 138
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "139"
down_revision: str | None = "138"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Document column list as left by migration 117. Must match exactly.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
# Buy Tokens content). Keep this list narrow on purpose: anything added
# here flows into WAL and IndexedDB for every connected browser.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
conn.execute(
sa.text(
"SELECT pg_terminate_backend(l.pid) "
"FROM pg_locks l "
"JOIN pg_class c ON c.oid = l.relation "
"WHERE c.relname = :tbl "
" AND l.pid != pg_backend_pid()"
),
{"tbl": table},
)
def _has_zero_version(conn, table: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = '_0_version'"
),
{"tbl": table},
).fetchone()
is not None
)
def _build_publication_ddl(
documents_has_zero_ver: bool, user_has_zero_ver: bool
) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
doc_col_list = ", ".join(doc_cols)
user_col_list = ", ".join(user_cols)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state, "
f'"user" ({user_col_list})'
)
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
doc_col_list = ", ".join(doc_cols)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state"
)
def upgrade() -> None:
conn = op.get_bind()
# asyncpg requires LOCK TABLE inside a transaction block. Alembic already
# opened one via context.begin_transaction(), but the driver still errors
# unless we use an explicit SAVEPOINT (nested transaction) for this block.
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
# Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
# migration 117, so this is already DEFAULT. Re-assert anyway so
# the column-list publication stays valid (DEFAULT identity only
# requires the PK to be in the column list).
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
documents_has_zero_ver = _has_zero_version(conn, "documents")
conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))

View file

@ -0,0 +1,291 @@
"""rename premium token columns to credit micros and add cost_micros to token_usage
Migrates the premium quota system from a flat token counter to a USD-cost
based credit system, where 1 credit = 1 micro-USD ($0.000001).
Column renames (1:1 numerical mapping the prior $1 per 1M tokens Stripe
price means every existing value is already correct in the new unit, no
data transformation needed):
user.premium_tokens_limit -> premium_credit_micros_limit
user.premium_tokens_used -> premium_credit_micros_used
user.premium_tokens_reserved -> premium_credit_micros_reserved
premium_token_purchases.tokens_granted -> credit_micros_granted
New column for cost auditing per turn:
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
The "user" table is in zero_publication's column list (added in 139), so
this migration must drop and recreate the publication with the renamed
column names, otherwise zero-cache will replicate stale column names and
the FE Zero schema will fail to bind.
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
column-not-found errors from a stale catalog snapshot.
Revision ID: 140
Revises: 139
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "140"
down_revision: str | None = "139"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Replicates 139's document column list verbatim — must stay in sync.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Same five live-meter fields as 139, with the renamed column names.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_credit_micros_limit",
"premium_credit_micros_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
conn.execute(
sa.text(
"SELECT pg_terminate_backend(l.pid) "
"FROM pg_locks l "
"JOIN pg_class c ON c.oid = l.relation "
"WHERE c.relname = :tbl "
" AND l.pid != pg_backend_pid()"
),
{"tbl": table},
)
def _has_zero_version(conn, table: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = '_0_version'"
),
{"tbl": table},
).fetchone()
is not None
)
def _column_exists(conn, table: str, column: str) -> bool:
return (
conn.execute(
sa.text(
"SELECT 1 FROM information_schema.columns "
"WHERE table_name = :tbl AND column_name = :col"
),
{"tbl": table, "col": column},
).fetchone()
is not None
)
def _build_publication_ddl(
user_cols: list[str],
*,
documents_has_zero_ver: bool,
user_has_zero_ver: bool,
) -> str:
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
user_col_list_with_meta = user_cols + (
['"_0_version"'] if user_has_zero_ver else []
)
doc_col_list = ", ".join(doc_cols)
user_col_list = ", ".join(user_col_list_with_meta)
return (
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
f"notifications, "
f"documents ({doc_col_list}), "
f"folders, "
f"search_source_connectors, "
f"new_chat_messages, "
f"chat_comments, "
f"chat_session_state, "
f'"user" ({user_col_list})'
)
def upgrade() -> None:
conn = op.get_bind()
# ------------------------------------------------------------------
# 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
# dev environments are safe.
# ------------------------------------------------------------------
if not _column_exists(conn, "token_usage", "cost_micros"):
op.add_column(
"token_usage",
sa.Column(
"cost_micros",
sa.BigInteger(),
nullable=False,
server_default="0",
),
)
# ------------------------------------------------------------------
# 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
# ------------------------------------------------------------------
if _column_exists(
conn, "premium_token_purchases", "tokens_granted"
) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
op.alter_column(
"premium_token_purchases",
"tokens_granted",
new_column_name="credit_micros_granted",
)
# ------------------------------------------------------------------
# 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
#
# We must drop the publication first (it references the old column
# names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
# in a transaction block; alembic's outer transaction already holds
# one, but a SAVEPOINT keeps the LOCK + DDL atomic.
# ------------------------------------------------------------------
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
# Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
# publications require at least the PK to be in the column list,
# which is true for both the old and new shape.
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
# Drop the publication BEFORE renaming columns, otherwise Postgres
# rejects the rename: "cannot drop column ... referenced by
# publication".
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for old, new in (
("premium_tokens_limit", "premium_credit_micros_limit"),
("premium_tokens_used", "premium_credit_micros_used"),
("premium_tokens_reserved", "premium_credit_micros_reserved"),
):
if _column_exists(conn, "user", old) and not _column_exists(
conn, "user", new
):
op.alter_column("user", old, new_column_name=new)
# Update the server_default on the renamed limit column so newly
# inserted users get $5 of credit (== 5_000_000 micros) by
# default. Existing rows are unaffected.
op.alter_column(
"user",
"premium_credit_micros_limit",
server_default="5000000",
)
# Recreate the publication with the new column names.
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
USER_COLS,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
def downgrade() -> None:
"""Revert the rename and drop ``cost_micros``.
Mirrors ``upgrade``: drop the publication, rename columns back, drop
the new column, recreate the publication with the old column list.
Same zero-cache stop/reset runbook applies in reverse.
"""
conn = op.get_bind()
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
with tx:
conn.execute(sa.text("SET lock_timeout = '10s'"))
_terminate_blocked_pids(conn, "user")
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
for new, old in (
("premium_credit_micros_limit", "premium_tokens_limit"),
("premium_credit_micros_used", "premium_tokens_used"),
("premium_credit_micros_reserved", "premium_tokens_reserved"),
):
if _column_exists(conn, "user", new) and not _column_exists(
conn, "user", old
):
op.alter_column("user", new, new_column_name=old)
op.alter_column(
"user",
"premium_tokens_limit",
server_default="5000000",
)
legacy_user_cols = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
documents_has_zero_ver = _has_zero_version(conn, "documents")
user_has_zero_ver = _has_zero_version(conn, "user")
conn.execute(
sa.text(
_build_publication_ddl(
legacy_user_cols,
documents_has_zero_ver=documents_has_zero_ver,
user_has_zero_ver=user_has_zero_ver,
)
)
)
if _column_exists(
conn, "premium_token_purchases", "credit_micros_granted"
) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
op.alter_column(
"premium_token_purchases",
"credit_micros_granted",
new_column_name="tokens_granted",
)
if _column_exists(conn, "token_usage", "cost_micros"):
op.drop_column("token_usage", "cost_micros")

View file

@ -0,0 +1,66 @@
"""141_unique_chat_message_turn_role
Revision ID: 141
Revises: 140
Create Date: 2026-05-04
Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)``
where ``turn_id IS NOT NULL``.
Why
---
The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being
moved to write its own ``new_chat_messages`` rows server-side instead of
relying on the frontend's later ``POST /threads/{id}/messages`` call. This
closes the "ghost-thread" abuse vector where authenticated callers got free
LLM completions while ``new_chat_messages`` stayed empty.
For server-side and legacy frontend writes to coexist we need an idempotency
key. The natural triple is ``(thread_id, turn_id, role)``: the server issues
exactly one ``turn_id`` per turn, and a turn produces at most one user
message and one assistant message. Whichever side wins the race writes the
row; the loser hits ``IntegrityError`` and recovers gracefully.
Partial ``WHERE turn_id IS NOT NULL`` so:
* Legacy rows that predate the ``turn_id`` column (migration 136) keep
co-existing without de-dup.
* Clone / snapshot inserts in
``app/services/public_chat_service.py`` that build ``NewChatMessage``
without ``turn_id`` are unaffected (multiple snapshot copies of the same
user/assistant pair are intentional).
This index coexists with the existing single-column ``ix_new_chat_messages_turn_id``
from migration 136 no collision.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "141"
down_revision: str | None = "140"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
INDEX_NAME = "uq_new_chat_messages_thread_turn_role"
TABLE_NAME = "new_chat_messages"
def upgrade() -> None:
op.create_index(
INDEX_NAME,
TABLE_NAME,
["thread_id", "turn_id", "role"],
unique=True,
postgresql_where=sa.text("turn_id IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index(INDEX_NAME, table_name=TABLE_NAME)

View file

@ -0,0 +1,134 @@
"""142_token_usage_message_id_unique
Revision ID: 142
Revises: 141
Create Date: 2026-05-04
Add a partial unique index on ``token_usage(message_id)`` where
``message_id IS NOT NULL``.
Why
---
Two writers can race on the same assistant turn's ``token_usage`` row:
* ``finalize_assistant_turn`` (server-side, called from the streaming
finally block in ``stream_new_chat`` / ``stream_resume_chat``)
* ``append_message``'s recovery branch in
``app/routes/new_chat_routes.py`` (legacy frontend round-trip)
Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a
micro-second-aligned race could observe "no row" on each side and double
INSERT, producing duplicate ``token_usage`` rows for the same
``message_id``.
A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``)
turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING``
no-ops for the loser, hard-eliminating the race at the DB level. Partial
because non-chat usage rows (indexing, image generation, podcasts) keep
``message_id`` NULL they're per-event, no de-dup needed.
Pre-flight
----------
Today's schema only has a non-unique index on ``message_id`` so a
duplicate population could already exist from any past race. We:
* Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``).
* If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50)
we dedupe by deleting all but the smallest ``id`` per group.
* If the count exceeds the threshold we abort with a descriptive
error rather than silently mutate prod data operator must
investigate before retrying.
Concurrency
-----------
``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid
stalling production writes during deploy (a regular ``CREATE INDEX``
holds an ACCESS EXCLUSIVE lock for the duration of the build, which
would block ``token_usage`` INSERTs for every active streaming chat).
The trade-off is a slower migration (CONCURRENTLY scans the table
twice) and the ``CREATE`` statement cannot run inside alembic's default
transaction wrapper ``autocommit_block()`` handles that.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "142"
down_revision: str | None = "141"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
INDEX_NAME = "uq_token_usage_message_id"
TABLE_NAME = "token_usage"
# Refuse to silently mutate prod data if the duplicate population is
# unexpectedly large — operator should investigate the upstream cause
# before retrying. 50 is comfortably above any plausible duplicate
# count from the existing race window (the race is microseconds wide).
DUPLICATE_ABORT_THRESHOLD = 50
def upgrade() -> None:
conn = op.get_bind()
dup_groups = conn.execute(
sa.text(
"SELECT message_id, COUNT(*) AS n "
"FROM token_usage "
"WHERE message_id IS NOT NULL "
"GROUP BY message_id "
"HAVING COUNT(*) > 1"
)
).fetchall()
if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD:
raise RuntimeError(
f"token_usage has {len(dup_groups)} duplicate message_id groups "
f"(threshold={DUPLICATE_ABORT_THRESHOLD}). "
"Resolve the duplicates manually before re-running this migration."
)
if dup_groups:
# Delete all but the smallest-id row per duplicate group. The
# smallest id is by definition the earliest insert, so we keep
# the row most likely to reflect the actual stream's first
# successful write.
conn.execute(
sa.text(
"""
DELETE FROM token_usage
WHERE id IN (
SELECT id FROM (
SELECT
id,
row_number() OVER (
PARTITION BY message_id ORDER BY id ASC
) AS rn
FROM token_usage
WHERE message_id IS NOT NULL
) ranked
WHERE rn > 1
)
"""
)
)
# CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop
# alembic's auto-transaction for this op only.
with op.get_context().autocommit_block():
op.execute(
f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} "
f"ON {TABLE_NAME} (message_id) "
"WHERE message_id IS NOT NULL"
)
def downgrade() -> None:
with op.get_context().autocommit_block():
op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}")

View file

@ -11,7 +11,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from .middleware import build_main_agent_deepagent_middleware
from app.agents.multi_agent_chat.subagents.shared.permissions import ( from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions, ToolsPermissions,
) )
@ -20,6 +19,8 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility from app.db import ChatVisibility
from .middleware import build_main_agent_deepagent_middleware
def build_compiled_agent_graph_sync( def build_compiled_agent_graph_sync(
*, *,

View file

@ -31,8 +31,8 @@ from .propagation import (
from .resume import ( from .resume import (
build_resume_command, build_resume_command,
fan_out_decisions_to_match, fan_out_decisions_to_match,
hitlrequest_action_count,
get_first_pending_subagent_interrupt, get_first_pending_subagent_interrupt,
hitlrequest_action_count,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,7 +51,9 @@ def build_task_tool_with_parent_config(
) )
if task_description is None: if task_description is None:
description = TASK_TOOL_DESCRIPTION.format(available_agents=subagent_description_str) description = TASK_TOOL_DESCRIPTION.format(
available_agents=subagent_description_str
)
elif "{available_agents}" in task_description: elif "{available_agents}" in task_description:
description = task_description.format(available_agents=subagent_description_str) description = task_description.format(available_agents=subagent_description_str)
else: else:
@ -90,11 +92,11 @@ def build_task_tool_with_parent_config(
def task( def task(
description: Annotated[ description: Annotated[
str, str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501 "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
], ],
subagent_type: Annotated[ subagent_type: Annotated[
str, str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501 "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
], ],
runtime: ToolRuntime, runtime: ToolRuntime,
) -> str | Command: ) -> str | Command:
@ -119,7 +121,9 @@ def build_task_tool_with_parent_config(
if callable(get_state): if callable(get_state):
try: try:
snapshot = get_state(sub_config) snapshot = get_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) pending_id, pending_value = get_first_pending_subagent_interrupt(
snapshot
)
except Exception: except Exception:
# Fail loud if a resume is queued: silent fallback would # Fail loud if a resume is queued: silent fallback would
# replay the original interrupt to the user. # replay the original interrupt to the user.
@ -158,11 +162,11 @@ def build_task_tool_with_parent_config(
async def atask( async def atask(
description: Annotated[ description: Annotated[
str, str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501 "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
], ],
subagent_type: Annotated[ subagent_type: Annotated[
str, str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501 "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
], ],
runtime: ToolRuntime, runtime: ToolRuntime,
) -> str | Command: ) -> str | Command:
@ -186,7 +190,9 @@ def build_task_tool_with_parent_config(
if callable(aget_state): if callable(aget_state):
try: try:
snapshot = await aget_state(sub_config) snapshot = await aget_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) pending_id, pending_value = get_first_pending_subagent_interrupt(
snapshot
)
except Exception: except Exception:
if has_surfsense_resume(runtime): if has_surfsense_resume(runtime):
logger.exception( logger.exception(

View file

@ -23,7 +23,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from ...context_prune.prune_tool_names import safe_exclude_tools
from app.agents.multi_agent_chat.subagents import ( from app.agents.multi_agent_chat.subagents import (
build_subagents, build_subagents,
get_subagents_to_exclude, get_subagents_to_exclude,
@ -66,6 +65,7 @@ from app.agents.new_chat.plugin_loader import (
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from app.db import ChatVisibility from app.db import ChatVisibility
from ...context_prune.prune_tool_names import safe_exclude_tools
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware

View file

@ -14,8 +14,10 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync from app.agents.multi_agent_chat.subagents import (
from ..tools import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED get_subagents_to_exclude,
main_prompt_registry_subagent_lines,
)
from app.agents.multi_agent_chat.subagents.mcp_tools.index import ( from app.agents.multi_agent_chat.subagents.mcp_tools.index import (
load_mcp_tools_by_connector, load_mcp_tools_by_connector,
) )
@ -24,17 +26,19 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig from app.agents.new_chat.llm_config import AgentConfig
from app.agents.multi_agent_chat.subagents import (
get_subagents_to_exclude,
main_prompt_registry_subagent_lines,
)
from ..system_prompt import build_main_agent_system_prompt
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
from app.agents.new_chat.tools.registry import build_tools_async from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility from app.db import ChatVisibility
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
from ..system_prompt import build_main_agent_system_prompt
from ..tools import (
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
)
_perf_log = get_perf_logger() _perf_log = get_perf_logger()

View file

@ -2,6 +2,9 @@
from __future__ import annotations from __future__ import annotations
from .index import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED from .index import (
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
)
__all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"] __all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"]

View file

@ -13,7 +13,9 @@ from .resume import create_generate_resume_tool
from .video_presentation import create_generate_video_presentation_tool from .video_presentation import create_generate_video_presentation_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs} resolved_dependencies = {**(dependencies or {}), **kwargs}
podcast = create_generate_podcast_tool( podcast = create_generate_podcast_tool(
search_space_id=resolved_dependencies["search_space_id"], search_space_id=resolved_dependencies["search_space_id"],

View file

@ -10,7 +10,9 @@ from app.db import ChatVisibility
from .update_memory import create_update_memory_tool, create_update_team_memory_tool from .update_memory import create_update_memory_tool, create_update_team_memory_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs} resolved_dependencies = {**(dependencies or {}), **kwargs}
if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE: if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE:
mem = create_update_team_memory_tool( mem = create_update_team_memory_tool(
@ -18,7 +20,10 @@ def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) ->
db_session=resolved_dependencies["db_session"], db_session=resolved_dependencies["db_session"],
llm=resolved_dependencies.get("llm"), llm=resolved_dependencies.get("llm"),
) )
return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []} return {
"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}],
"ask": [],
}
mem = create_update_memory_tool( mem = create_update_memory_tool(
user_id=resolved_dependencies["user_id"], user_id=resolved_dependencies["user_id"],
db_session=resolved_dependencies["db_session"], db_session=resolved_dependencies["db_session"],

View file

@ -11,14 +11,20 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
from .web_search import create_web_search_tool from .web_search import create_web_search_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs} resolved_dependencies = {**(dependencies or {}), **kwargs}
web = create_web_search_tool( web = create_web_search_tool(
search_space_id=resolved_dependencies.get("search_space_id"), search_space_id=resolved_dependencies.get("search_space_id"),
available_connectors=resolved_dependencies.get("available_connectors"), available_connectors=resolved_dependencies.get("available_connectors"),
) )
scrape = create_scrape_webpage_tool(firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key")) scrape = create_scrape_webpage_tool(
docs = create_search_surfsense_docs_tool(db_session=resolved_dependencies["db_session"]) firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key")
)
docs = create_search_surfsense_docs_tool(
db_session=resolved_dependencies["db_session"]
)
return { return {
"allow": [ "allow": [
{"name": getattr(web, "name", "") or "", "tool": web}, {"name": getattr(web, "name", "") or "", "tool": web},

View file

@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
) )
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs} _ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []} return {"allow": [], "ask": []}

View file

@ -12,7 +12,9 @@ from .search_events import create_search_calendar_events_tool
from .update_event import create_update_calendar_event_tool from .update_event import create_update_calendar_event_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs} resolved_dependencies = {**(dependencies or {}), **kwargs}
session_dependencies = { session_dependencies = {
"db_session": resolved_dependencies["db_session"], "db_session": resolved_dependencies["db_session"],

View file

@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
) )
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs} _ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []} return {"allow": [], "ask": []}

View file

@ -11,7 +11,9 @@ from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool from .update_page import create_update_confluence_page_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
resolved_dependencies = {**(dependencies or {}), **kwargs} resolved_dependencies = {**(dependencies or {}), **kwargs}
session_dependencies = { session_dependencies = {
"db_session": resolved_dependencies["db_session"], "db_session": resolved_dependencies["db_session"],

View file

@ -11,7 +11,9 @@ from .read_messages import create_read_discord_messages_tool
from .send_message import create_send_discord_message_tool from .send_message import create_send_discord_message_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -10,7 +10,9 @@ from .create_file import create_create_dropbox_file_tool
from .trash_file import create_delete_dropbox_file_tool from .trash_file import create_delete_dropbox_file_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -14,7 +14,9 @@ from .trash_email import create_trash_gmail_email_tool
from .update_draft import create_update_gmail_draft_tool from .update_draft import create_update_gmail_draft_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -10,7 +10,9 @@ from .create_file import create_create_google_drive_file_tool
from .trash_file import create_delete_google_drive_file_tool from .trash_file import create_delete_google_drive_file_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -11,7 +11,9 @@ from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool from .update_issue import create_update_jira_issue_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -11,7 +11,9 @@ from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool from .update_issue import create_update_linear_issue_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -11,7 +11,9 @@ from .list_events import create_list_luma_events_tool
from .read_event import create_read_luma_event_tool from .read_event import create_read_luma_event_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -11,7 +11,9 @@ from .delete_page import create_delete_notion_page_tool
from .update_page import create_update_notion_page_tool from .update_page import create_update_notion_page_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -10,7 +10,9 @@ from .create_file import create_create_onedrive_file_tool
from .trash_file import create_delete_onedrive_file_tool from .trash_file import create_delete_onedrive_file_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
) )
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
_ = {**(dependencies or {}), **kwargs} _ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []} return {"allow": [], "ask": []}

View file

@ -11,7 +11,9 @@ from .read_messages import create_read_teams_messages_tool
from .send_message import create_send_teams_message_tool from .send_message import create_send_teams_message_tool
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions: def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs} d = {**(dependencies or {}), **kwargs}
common = { common = {
"db_session": d["db_session"], "db_session": d["db_session"],

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
## Helper functions for fetching connector metadata maps ## Helper functions for fetching connector metadata maps
async def fetch_mcp_connector_metadata_maps( async def fetch_mcp_connector_metadata_maps(
session: AsyncSession, session: AsyncSession,
search_space_id: int, search_space_id: int,
@ -58,6 +59,7 @@ async def fetch_mcp_connector_metadata_maps(
## Helper functions for partitioning tools by connector agent ## Helper functions for partitioning tools by connector agent
def partition_mcp_tools_by_connector( def partition_mcp_tools_by_connector(
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
connector_id_to_type: dict[int, str], connector_id_to_type: dict[int, str],
@ -104,8 +106,10 @@ def partition_mcp_tools_by_connector(
return dict(buckets) return dict(buckets)
## Helper functions for splitting tools by permissions ## Helper functions for splitting tools by permissions
def _get_mcp_tool_name(tool: BaseTool) -> str: def _get_mcp_tool_name(tool: BaseTool) -> str:
meta: dict[str, Any] = getattr(tool, "metadata", None) or {} meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
orig = meta.get("mcp_original_tool_name") orig = meta.get("mcp_original_tool_name")
@ -139,6 +143,7 @@ def _split_tools_by_permissions(
## Main function to load MCP tools and split them by permissions for each connector agent ## Main function to load MCP tools and split them by permissions for each connector agent
async def load_mcp_tools_by_connector( async def load_mcp_tools_by_connector(
session: AsyncSession, session: AsyncSession,
search_space_id: int, search_space_id: int,
@ -148,9 +153,7 @@ async def load_mcp_tools_by_connector(
Pass ``bypass_internal_hitl=True`` so the subagent's Pass ``bypass_internal_hitl=True`` so the subagent's
``HumanInTheLoopMiddleware`` is the single HITL gate. ``HumanInTheLoopMiddleware`` is the single HITL gate.
""" """
flat = await load_mcp_tools( flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True)
session, search_space_id, bypass_internal_hitl=True
)
id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id) id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id)
buckets = partition_mcp_tools_by_connector(flat, id_map, name_map) buckets = partition_mcp_tools_by_connector(flat, id_map, name_map)
return { return {

View file

@ -8,6 +8,9 @@ from typing import Any, Protocol
from deepagents import SubAgent from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from app.agents.multi_agent_chat.constants import (
SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
)
from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import ( from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import (
build_subagent as build_deliverables_subagent, build_subagent as build_deliverables_subagent,
) )
@ -62,9 +65,6 @@ from app.agents.multi_agent_chat.subagents.connectors.slack.agent import (
from app.agents.multi_agent_chat.subagents.connectors.teams.agent import ( from app.agents.multi_agent_chat.subagents.connectors.teams.agent import (
build_subagent as build_teams_subagent, build_subagent as build_teams_subagent,
) )
from app.agents.multi_agent_chat.constants import (
SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
)
from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
read_md_file, read_md_file,
) )
@ -105,6 +105,7 @@ SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = {
"teams": build_teams_subagent, "teams": build_teams_subagent,
} }
def _route_resource_package(builder: SubagentBuilder) -> str: def _route_resource_package(builder: SubagentBuilder) -> str:
mod = builder.__module__ mod = builder.__module__
return mod[: -len(".agent")] if mod.endswith(".agent") else mod.rsplit(".", 1)[0] return mod[: -len(".agent")] if mod.endswith(".agent") else mod.rsplit(".", 1)[0]

View file

@ -0,0 +1,357 @@
"""TTL-LRU cache for compiled SurfSense deep agents.
Why this exists
---------------
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
turn:
1. Discover connectors & document types from Postgres (~50-200ms)
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
3. Compose the system prompt
4. Construct ~15 middleware instances (CPU)
5. Eagerly compile the general-purpose subagent
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
which builds a second LangGraph + Pydantic schemas ~1.5-2s of pure
CPU work)
6. Compile the outer LangGraph
For a single thread, all six steps produce the SAME object on every turn
unless the user has changed their LLM config, toggled a feature flag,
added a connector, etc. The right answer is to compile ONCE per
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
every subsequent turn on the same thread.
Why a per-thread key (not a global pool)
----------------------------------------
Most middleware in the SurfSense stack captures per-thread state in
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
would silently leak state across users and threads. Keying the cache on
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
turns on the same thread without changing any middleware's behavior.
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
(read via ``runtime.context``) so the cache can collapse to a single
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
then, per-thread keying is the only safe option.
Cache shape
-----------
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
minutes matches a typical chat session). ``maxsize`` (default 256)
caps memory; LRU evicts least-recently-used on overflow.
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
cold misses on the same key wait for the first build instead of
building N times.
* Process-local: this is an in-memory cache. Multi-replica deployments
pay the build cost once per replica per key. That's fine; the working
set per replica is small (one entry per active thread on that replica).
Telemetry
---------
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
* ``hit`` cache hit, microseconds-fast
* ``miss`` first build for this key, includes build duration
* ``stale`` entry was found but expired; rebuilt
* ``evict`` LRU eviction (size-limited)
* ``size`` current cache occupancy at lookup time
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# ---------------------------------------------------------------------------
# Public API: signature helpers (cache key components)
# ---------------------------------------------------------------------------
def stable_hash(*parts: Any) -> str:
"""Compute a deterministic SHA1 of the str repr of ``parts``.
Used for cache key components that need a fixed-width representation
(system prompt, tool list, etc.). SHA1 is fine here this is not a
security boundary, just a content fingerprint.
"""
h = hashlib.sha1(usedforsecurity=False)
for p in parts:
h.update(repr(p).encode("utf-8", errors="replace"))
h.update(b"\x1f") # ASCII unit separator between parts
return h.hexdigest()
def tools_signature(
tools: list[Any] | tuple[Any, ...],
*,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> str:
"""Hash the bound-tool surface for cache-key purposes.
The signature changes whenever:
* A tool is added or removed from the bound list (built-in toggles,
MCP tools loaded for the user changes, gating rules flip, etc.).
* The available connectors / document types for the search space
change (new connector added, last connector removed, new document
type indexed). Because :func:`get_connector_gated_tools` derives
``modified_disabled_tools`` from ``available_connectors``, the
tool surface is technically already covered but we hash the
connector list separately so an empty-list "no tools changed"
situation still rotates the key when, say, the user re-adds a
connector that gates a tool we were already not exposing.
Stays stable across:
* Process restarts (tool names + descriptions are static).
* Different replicas (everyone gets the same hash for the same
inputs).
"""
tool_descriptors = sorted(
(getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
)
connectors = sorted(available_connectors or [])
doc_types = sorted(available_document_types or [])
return stable_hash(tool_descriptors, connectors, doc_types)
def flags_signature(flags: Any) -> str:
"""Hash the resolved :class:`AgentFeatureFlags` dataclass.
Frozen dataclasses are deterministically reprable, so a SHA1 of their
repr is a stable fingerprint. Restart safe (flags are read once at
process boot).
"""
return stable_hash(repr(flags))
def system_prompt_hash(system_prompt: str) -> str:
"""Hash a system prompt string. Cheap, ~30µs for typical prompts."""
return hashlib.sha1(
system_prompt.encode("utf-8", errors="replace"),
usedforsecurity=False,
).hexdigest()
# ---------------------------------------------------------------------------
# Cache implementation
# ---------------------------------------------------------------------------
@dataclass
class _Entry:
value: Any
created_at: float
last_used_at: float
class _AgentCache:
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
NOT THREAD-SAFE in the multithreading sense designed for a single
asyncio event loop. Uvicorn runs one event loop per worker process,
so this is fine; multi-worker deployments simply each maintain their
own cache.
"""
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
self._maxsize = maxsize
self._ttl = ttl_seconds
self._entries: OrderedDict[str, _Entry] = OrderedDict()
# One lock per key — guards "build" so concurrent cold misses on
# the same key wait for the first build instead of all racing.
self._locks: dict[str, asyncio.Lock] = {}
def _now(self) -> float:
return time.monotonic()
def _is_fresh(self, entry: _Entry) -> bool:
return (self._now() - entry.created_at) < self._ttl
def _evict_if_full(self) -> None:
while len(self._entries) >= self._maxsize:
evicted_key, _ = self._entries.popitem(last=False)
self._locks.pop(evicted_key, None)
_perf_log.info(
"[agent_cache] evict key=%s reason=lru size=%d",
_short(evicted_key),
len(self._entries),
)
def _touch(self, key: str, entry: _Entry) -> None:
entry.last_used_at = self._now()
self._entries.move_to_end(key, last=True)
async def get_or_build(
self,
key: str,
*,
builder: Callable[[], Awaitable[Any]],
) -> Any:
"""Return the cached value for ``key`` or call ``builder()`` to make it.
``builder`` MUST be idempotent concurrent cold misses on the
same key collapse to a single ``builder()`` call (the others
wait on the in-flight lock and observe the populated entry on
wake).
"""
# Fast path: hot hit.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
# Stale entry — drop it; rebuild below.
if entry is not None and not self._is_fresh(entry):
_perf_log.info(
"[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
_short(key),
self._now() - entry.created_at,
self._ttl,
)
self._entries.pop(key, None)
# Slow path: serialize concurrent misses for the same key.
lock = self._locks.setdefault(key, asyncio.Lock())
async with lock:
# Double-check after acquiring the lock — another waiter may
# have populated the entry while we slept.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
t0 = time.perf_counter()
try:
value = await builder()
except BaseException:
# Don't cache failed builds; let the next caller retry.
_perf_log.warning(
"[agent_cache] build_failed key=%s elapsed=%.3fs",
_short(key),
time.perf_counter() - t0,
)
raise
elapsed = time.perf_counter() - t0
# Insert + evict.
self._evict_if_full()
now = self._now()
self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
self._entries.move_to_end(key, last=True)
_perf_log.info(
"[agent_cache] miss key=%s build=%.3fs size=%d",
_short(key),
elapsed,
len(self._entries),
)
return value
def invalidate(self, key: str) -> bool:
"""Drop a single entry; return True if anything was removed."""
removed = self._entries.pop(key, None) is not None
self._locks.pop(key, None)
if removed:
_perf_log.info(
"[agent_cache] invalidate key=%s size=%d",
_short(key),
len(self._entries),
)
return removed
def invalidate_prefix(self, prefix: str) -> int:
"""Drop every entry whose key starts with ``prefix``. Returns count."""
keys = [k for k in self._entries if k.startswith(prefix)]
for k in keys:
self._entries.pop(k, None)
self._locks.pop(k, None)
if keys:
_perf_log.info(
"[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
_short(prefix),
len(keys),
len(self._entries),
)
return len(keys)
def clear(self) -> None:
n = len(self._entries)
self._entries.clear()
self._locks.clear()
if n:
_perf_log.info("[agent_cache] clear removed=%d", n)
def stats(self) -> dict[str, Any]:
return {
"size": len(self._entries),
"maxsize": self._maxsize,
"ttl_seconds": self._ttl,
}
def _short(key: str, n: int = 16) -> str:
"""Truncate keys for log lines so they don't blow up log volume."""
return key if len(key) <= n else f"{key[:n]}..."
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
def get_cache() -> _AgentCache:
"""Return the process-wide compiled-agent cache singleton."""
return _cache
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
"""Replace the singleton with a fresh cache. Tests only."""
global _cache
_cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
return _cache
__all__ = [
"flags_signature",
"get_cache",
"reload_for_tests",
"stable_hash",
"system_prompt_hash",
"tools_signature",
]

View file

@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
This lets us swap in ``SurfSenseFilesystemMiddleware`` a customisable This lets us swap in ``SurfSenseFilesystemMiddleware`` a customisable
subclass of the default ``FilesystemMiddleware`` while preserving every subclass of the default ``FilesystemMiddleware`` while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents, other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
summarisation, prompt-caching, etc.). summarisation, etc.). Prompt caching is configured at LLM-build time via
``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
than as a middleware.
""" """
import asyncio import asyncio
@ -33,12 +35,18 @@ from langchain.agents.middleware import (
TodoListMiddleware, TodoListMiddleware,
ToolCallLimitMiddleware, ToolCallLimitMiddleware,
) )
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_backends import build_backend_resolver
@ -52,6 +60,7 @@ from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware, DedupHITLToolCallsMiddleware,
DoomLoopMiddleware, DoomLoopMiddleware,
FileIntentMiddleware, FileIntentMiddleware,
FlattenSystemMessageMiddleware,
KnowledgeBasePersistenceMiddleware, KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware, KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware, KnowledgeTreeMiddleware,
@ -74,6 +83,7 @@ from app.agents.new_chat.plugin_loader import (
load_allowed_plugin_names_from_env, load_allowed_plugin_names_from_env,
load_plugin_middlewares, load_plugin_middlewares,
) )
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.subagents import build_specialized_subagents
from app.agents.new_chat.system_prompt import ( from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt, build_configurable_system_prompt,
@ -94,6 +104,39 @@ from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger() _perf_log = get_perf_logger()
def _resolve_prompt_model_name(
agent_config: AgentConfig | None,
llm: BaseChatModel,
) -> str | None:
"""Resolve the model id to feed to provider-variant detection.
Preference order (matches the established idiom in
``llm_router_service.py`` see ``params.get("base_model") or
params.get("model", "")`` usages there):
1. ``agent_config.litellm_params["base_model"]`` required for Azure
deployments where ``model_name`` is the deployment slug, not the
underlying family. Without this, a deployment named e.g.
``"prod-chat-001"`` would silently miss every provider regex.
2. ``agent_config.model_name`` the user's configured model id.
3. ``getattr(llm, "model", None)`` fallback for direct callers that
don't supply an ``AgentConfig`` (currently a defensive path; all
production callers pass ``agent_config``).
Returns ``None`` when nothing is available; ``compose_system_prompt``
treats that as the ``"default"`` variant (no provider block emitted).
"""
if agent_config is not None:
params = agent_config.litellm_params or {}
base_model = params.get("base_model")
if isinstance(base_model, str) and base_model.strip():
return base_model
if agent_config.model_name:
return agent_config.model_name
return getattr(llm, "model", None)
# ============================================================================= # =============================================================================
# Connector Type Mapping # Connector Type Mapping
# ============================================================================= # =============================================================================
@ -279,6 +322,14 @@ async def create_surfsense_deep_agent(
) )
""" """
_t_agent_total = time.perf_counter() _t_agent_total = time.perf_counter()
# Layer thread-aware prompt caching onto the LLM. Idempotent with the
# build-time call in ``llm_config.py``; this run merely adds
# ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
# configs now that ``thread_id`` is known. No-op when ``thread_id`` is
# None or the provider is non-OpenAI-family.
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
filesystem_selection = filesystem_selection or FilesystemSelection() filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver( backend_resolver = build_backend_resolver(
filesystem_selection, filesystem_selection,
@ -287,23 +338,39 @@ async def create_surfsense_deep_agent(
else None, else None,
) )
# Discover available connectors and document types for this search space # Discover available connectors and document types for this search space.
#
# NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
# ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
# SQLAlchemy explicitly forbids concurrent operations on the same session
# ("This session is provisioning a new connection; concurrent operations
# are not permitted on the same session"). The Phase 1.4 in-process TTL
# cache in ``connector_service`` already collapses the warm path to a
# near-zero pair of dict lookups, so sequential awaits cost nothing in
# the common case while remaining correct on cold cache misses.
available_connectors: list[str] | None = None available_connectors: list[str] | None = None
available_document_types: list[str] | None = None available_document_types: list[str] | None = None
_t0 = time.perf_counter() _t0 = time.perf_counter()
try: try:
connector_types = await connector_service.get_available_connectors( try:
search_space_id connector_types_result = await connector_service.get_available_connectors(
) search_space_id
if connector_types: )
available_connectors = _map_connectors_to_searchable_types(connector_types) if connector_types_result:
available_connectors = _map_connectors_to_searchable_types(
connector_types_result
)
except Exception as e:
logging.warning("Failed to discover available connectors: %s", e)
available_document_types = await connector_service.get_available_document_types( try:
search_space_id available_document_types = (
) await connector_service.get_available_document_types(search_space_id)
)
except Exception as e: except Exception as e:
logging.warning("Failed to discover available document types: %s", e)
except Exception as e: # pragma: no cover - defensive outer guard
logging.warning(f"Failed to discover available connectors/document types: {e}") logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info( _perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs", "[create_agent] Connector/doc-type discovery in %.3fs",
@ -398,6 +465,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names, enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names, disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools, mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
) )
else: else:
system_prompt = build_surfsense_system_prompt( system_prompt = build_surfsense_system_prompt(
@ -405,6 +473,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names, enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names, disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools, mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
) )
_perf_log.info( _perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@ -424,29 +493,77 @@ async def create_surfsense_deep_agent(
# entire middleware build + main-graph compile into a single # entire middleware build + main-graph compile into a single
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
# event loop stays responsive. # event loop stays responsive.
#
# PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
# on every per-request value that any middleware in the stack closes
# over in ``__init__`` — drop one and you risk leaking state across
# threads. Hits collapse this whole block to a microsecond lookup;
# misses pay the original CPU cost AND populate the cache.
config_id = agent_config.config_id if agent_config is not None else None
async def _build_agent() -> Any:
return await asyncio.to_thread(
_build_compiled_agent_blocking,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
# ``mentioned_document_ids`` is consumed by
# ``KnowledgePriorityMiddleware`` per turn via
# ``runtime.context`` (Phase 1.5). We still pass the
# caller-provided list here for the legacy fallback path
# (cache disabled / context not propagated) — the middleware
# drains its own copy after the first read so a cached graph
# never replays stale mentions.
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=_max_input_tokens,
flags=_flags,
checkpointer=checkpointer,
)
_t0 = time.perf_counter() _t0 = time.perf_counter()
agent = await asyncio.to_thread( if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
_build_compiled_agent_blocking, # Cache key components — order matters only for human readability;
llm=llm, # the resulting hash is what's stored. Every component must
tools=tools, # rotate on a real shape change AND stay stable across identical
final_system_prompt=final_system_prompt, # invocations.
backend_resolver=backend_resolver, cache_key = stable_hash(
filesystem_mode=filesystem_selection.mode, "v1", # schema version of the key — bump if components change
search_space_id=search_space_id, config_id,
user_id=user_id, thread_id,
thread_id=thread_id, user_id,
visibility=visibility, search_space_id,
anon_session_id=anon_session_id, visibility,
available_connectors=available_connectors, filesystem_selection.mode,
available_document_types=available_document_types, anon_session_id,
mentioned_document_ids=mentioned_document_ids, tools_signature(
max_input_tokens=_max_input_tokens, tools,
flags=_flags, available_connectors=available_connectors,
checkpointer=checkpointer, available_document_types=available_document_types,
) ),
flags_signature(_flags),
system_prompt_hash(final_system_prompt),
_max_input_tokens,
# ``mentioned_document_ids`` deliberately omitted — middleware
# reads it from ``runtime.context`` (Phase 1.5).
)
agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
else:
agent = await _build_agent()
_perf_log.info( _perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs", "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
"on"
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
else "off",
) )
_perf_log.info( _perf_log.info(
@ -568,7 +685,6 @@ def _build_compiled_agent_blocking(
), ),
create_surfsense_compaction_middleware(llm, StateBackend), create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(), PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
] ]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
@ -998,6 +1114,14 @@ def _build_compiled_agent_blocking(
noop_mw, noop_mw,
retry_mw, retry_mw,
fallback_mw, fallback_mw,
# Coalesce a multi-text-block system message into one block
# immediately before the model call. Sits innermost on the
# system-message-mutation chain so it observes every appender
# (todo / filesystem / skills / subagents …) and prevents
# OpenRouter→Anthropic from redistributing ``cache_control``
# across N blocks and tripping Anthropic's 4-breakpoint cap.
# See ``middleware/flatten_system.py`` for full rationale.
FlattenSystemMessageMiddleware(),
# Tool-call repair must run after model emits but before # Tool-call repair must run after model emits but before
# permission / dedup / doom-loop interpret the calls. # permission / dedup / doom-loop interpret the calls.
repair_mw, repair_mw,
@ -1010,12 +1134,12 @@ def _build_compiled_agent_blocking(
action_log_mw, action_log_mw,
PatchToolCallsMiddleware(), PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)), DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
# Plugin slot — sits just before AnthropicCache so plugin-side # Plugin slot — sits at the tail so plugin-side transforms see the
# transforms see the final tool result and run before any # final tool result. Prompt caching is now applied at LLM build time
# caching heuristics. Multiple plugins in declared order; loader # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
# filtered by the admin allowlist already. # caching middleware is needed here. Multiple plugins run in declared
# order; loader filtered by the admin allowlist already.
*plugin_middlewares, *plugin_middlewares,
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
] ]
deepagent_middleware = [m for m in deepagent_middleware if m is not None] deepagent_middleware = [m for m in deepagent_middleware if m is not None]

View file

@ -1,10 +1,25 @@
""" """
Context schema definitions for SurfSense agents. Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent. This module defines the per-invocation context object passed to the SurfSense
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
The agent's compiled graph is the same across invocations (and cached by
``agent_cache``), so anything that varies per turn the user mentions a
specific document, the front-end issues a unique ``request_id``, etc.
MUST live on this context object instead of being captured into a
middleware ``__init__`` closure. Middlewares read fields back via
``runtime.context.<field>``; tools read them via ``runtime.context``.
This object is read inside both ``KnowledgePriorityMiddleware`` (for
``mentioned_document_ids``) and any future middleware that needs
per-request state without invalidating the compiled-agent cache.
""" """
from typing import NotRequired, TypedDict from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
class FileOperationContractState(TypedDict): class FileOperationContractState(TypedDict):
@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
turn_id: str turn_id: str
class SurfSenseContextSchema(TypedDict): @dataclass
class SurfSenseContextSchema:
""" """
Custom state schema for the SurfSense deep agent. Per-invocation context for the SurfSense deep agent.
This extends the default agent state with custom fields. Defaults are chosen so the dataclass can be safely default-constructed
The default state already includes: (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
- messages: Conversation history context is supplied see ``langgraph.runtime.Runtime``). All fields
- todos: Task list from TodoListMiddleware are optional; consumers must None-check before reading.
- files: Virtual filesystem from FilesystemMiddleware
We're adding fields needed for knowledge base search: Phase 1.5 fields:
- search_space_id: The user's search space ID search_space_id: Search space the request is scoped to.
- db_session: Database session (injected at runtime) mentioned_document_ids: KB documents the user @-mentioned this turn.
- connector_service: Connector service instance (injected at runtime) Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key that's the
whole point of putting it here.
file_operation_contract: One-shot file operation contract emitted
by ``FileIntentMiddleware`` for the upcoming turn.
turn_id / request_id: Correlation IDs surfaced by the streaming
task; populated for telemetry.
Phase 2 will extend with: thread_id, user_id, visibility,
filesystem_mode, anon_session_id, available_connectors,
available_document_types, created_by_id (everything currently captured
by middleware ``__init__`` closures).
""" """
search_space_id: int search_space_id: int | None = None
file_operation_contract: NotRequired[FileOperationContractState] mentioned_document_ids: list[int] = field(default_factory=list)
turn_id: NotRequired[str] file_operation_contract: FileOperationContractState | None = None
request_id: NotRequired[str] turn_id: str | None = None
# These are runtime-injected and won't be serialized request_id: str | None = None
# db_session and connector_service are passed when invoking the agent

View file

@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
These flags gate the newer agent middleware (some ported from OpenCode, These flags gate the newer agent middleware (some ported from OpenCode,
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
SurfSense-native). They follow a "default-OFF for risky things, SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
default-ON for safe upgrades, master kill-switch for everything new" model. image updates work even when older installs do not have newly introduced
environment variables. Risky/experimental integrations stay default OFF,
and the master kill-switch can still disable everything new.
All new middleware checks its flag at agent build time. If the master All new middleware checks its flag at agent build time. If the master
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
Examples Examples
-------- --------
Local development (recommended for trying everything except doom-loop / selector): Defaults:
SURFSENSE_ENABLE_CONTEXT_EDITING=true SURFSENSE_ENABLE_CONTEXT_EDITING=true
SURFSENSE_ENABLE_COMPACTION_V2=true SURFSENSE_ENABLE_COMPACTION_V2=true
SURFSENSE_ENABLE_RETRY_AFTER=true SURFSENSE_ENABLE_RETRY_AFTER=true
SURFSENSE_ENABLE_MODEL_FALLBACK=false
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else): Master kill-switch (overrides everything else):
@ -60,32 +65,28 @@ class AgentFeatureFlags:
disable_new_agent_stack: bool = False disable_new_agent_stack: bool = False
# Agent quality — context budget, retry/limits, name-repair, doom-loop # Agent quality — context budget, retry/limits, name-repair, doom-loop
enable_context_editing: bool = False enable_context_editing: bool = True
enable_compaction_v2: bool = False enable_compaction_v2: bool = True
enable_retry_after: bool = False enable_retry_after: bool = True
enable_model_fallback: bool = False enable_model_fallback: bool = False
enable_model_call_limit: bool = False enable_model_call_limit: bool = True
enable_tool_call_limit: bool = False enable_tool_call_limit: bool = True
enable_tool_call_repair: bool = False enable_tool_call_repair: bool = True
enable_doom_loop: bool = ( enable_doom_loop: bool = True
False # Default OFF until UI handles permission='doom_loop'
)
# Safety — permissions, concurrency, tool-set narrowing # Safety — permissions, concurrency, tool-set narrowing
enable_permission: bool = False # Default OFF for first deploy enable_permission: bool = True
enable_busy_mutex: bool = False enable_busy_mutex: bool = True
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
# Skills + subagents # Skills + subagents
enable_skills: bool = False enable_skills: bool = True
enable_specialized_subagents: bool = False enable_specialized_subagents: bool = True
enable_kb_planner_runnable: bool = False enable_kb_planner_runnable: bool = True
# Snapshot / revert # Snapshot / revert
enable_action_log: bool = False enable_action_log: bool = True
enable_revert_route: bool = ( enable_revert_route: bool = True
False # Backend ships before UI; route returns 503 until this flips
)
# Streaming parity v2 — opt in to LangChain's structured # Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input # ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@ -94,7 +95,7 @@ class AgentFeatureFlags:
# text path and the synthetic ``call_<run_id>`` tool-call id (no # text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136 # ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible. # ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = False enable_stream_parity_v2: bool = True
# Plugins # Plugins
enable_plugin_loader: bool = False enable_plugin_loader: bool = False
@ -102,6 +103,41 @@ class AgentFeatureFlags:
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
enable_otel: bool = False enable_otel: bool = False
# Performance — compiled-agent cache (Phase 1 + Phase 2).
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
# graph if the cache key matches (LLM config + thread + tool surface +
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
# wall clock from ~4-5s to <50µs on cache hits.
#
# SAFETY (Phase 2 unblocked this default-on):
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
# ``update_memory``, ``search_surfsense_docs``) now acquire fresh
# short-lived ``AsyncSession`` instances per call via
# :data:`async_session_maker`. The factory still accepts ``db_session``
# for registry compatibility but ``del``'s it immediately — see any
# of those files' factory docstrings for the rationale. The ``llm``
# closure is per-(provider, model, config_id) which is already in
# the cache key, so the LLM is safe to share across cached hits of
# the same key. The KB priority middleware reads
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
# not its constructor closure, so the same compiled agent serves
# turns with different mention lists correctly.
#
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
# environment if a regression surfaces. The path is exercised by
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
enable_agent_cache: bool = True
# Phase 1 (deferred — measure first): pre-build & share the
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
# misses. Only helps when the outer cache MISSES (cache hits already
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
# until we have data showing cold misses are frequent enough to
# justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False
@classmethod @classmethod
def from_env(cls) -> AgentFeatureFlags: def from_env(cls) -> AgentFeatureFlags:
"""Read flags from environment. """Read flags from environment.
@ -115,48 +151,76 @@ class AgentFeatureFlags:
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
"middleware is forced OFF for this build." "middleware is forced OFF for this build."
) )
return cls(disable_new_agent_stack=True) return cls(
disable_new_agent_stack=True,
enable_context_editing=False,
enable_compaction_v2=False,
enable_retry_after=False,
enable_model_fallback=False,
enable_model_call_limit=False,
enable_tool_call_limit=False,
enable_tool_call_repair=False,
enable_doom_loop=False,
enable_permission=False,
enable_busy_mutex=False,
enable_llm_tool_selector=False,
enable_skills=False,
enable_specialized_subagents=False,
enable_kb_planner_runnable=False,
enable_action_log=False,
enable_revert_route=False,
enable_stream_parity_v2=False,
enable_plugin_loader=False,
enable_otel=False,
enable_agent_cache=False,
enable_agent_cache_share_gp_subagent=False,
)
return cls( return cls(
disable_new_agent_stack=False, disable_new_agent_stack=False,
# Agent quality # Agent quality
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
enable_model_call_limit=_env_bool( enable_model_call_limit=_env_bool(
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
), ),
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
enable_tool_call_repair=_env_bool( enable_tool_call_repair=_env_bool(
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
), ),
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
# Safety # Safety
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
enable_llm_tool_selector=_env_bool( enable_llm_tool_selector=_env_bool(
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
), ),
# Skills + subagents # Skills + subagents
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
enable_specialized_subagents=_env_bool( enable_specialized_subagents=_env_bool(
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
), ),
enable_kb_planner_runnable=_env_bool( enable_kb_planner_runnable=_env_bool(
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
), ),
# Snapshot / revert # Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2 # Streaming parity v2
enable_stream_parity_v2=_env_bool( enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False "SURFSENSE_ENABLE_STREAM_PARITY_V2", True
), ),
# Plugins # Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability # Observability
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
# Performance
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
enable_agent_cache_share_gp_subagent=_env_bool(
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
),
) )
def any_new_middleware_enabled(self) -> bool: def any_new_middleware_enabled(self) -> bool:

View file

@ -27,6 +27,7 @@ from litellm import get_model_info
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.services.llm_router_service import ( from app.services.llm_router_service import (
AUTO_MODE_ID, AUTO_MODE_ID,
ChatLiteLLMRouter, ChatLiteLLMRouter,
@ -89,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk yield chunk
# Provider mapping for LiteLLM model string construction # Provider mapping for LiteLLM model string construction.
PROVIDER_MAP = { #
"OPENAI": "openai", # Single source of truth lives in
"ANTHROPIC": "anthropic", # :mod:`app.services.provider_capabilities` so the YAML loader (which
"GROQ": "groq", # runs during ``app.config`` class-body init) can resolve provider
"COHERE": "cohere", # prefixes without dragging the agent / tools tree into module load
"GOOGLE": "gemini", # order. Re-exported here under the historical ``PROVIDER_MAP`` name
"OLLAMA": "ollama_chat", # so existing callers (``llm_router_service``, ``image_gen_router_service``,
"MISTRAL": "mistral", # tests) keep working unchanged.
"AZURE_OPENAI": "azure", from app.services.provider_capabilities import ( # noqa: E402
"OPENROUTER": "openrouter", _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
"XAI": "xai", )
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@ -177,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog.
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
# which prefers OpenRouter's ``architecture.input_modalities`` and
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
supports_image_input: bool = True
@classmethod @classmethod
def from_auto_mode(cls) -> "AgentConfig": def from_auto_mode(cls) -> "AgentConfig":
""" """
@ -202,6 +191,12 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually
# contains at least one vision-capable deployment; the router
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
supports_image_input=True,
) )
@classmethod @classmethod
@ -215,10 +210,24 @@ class AgentConfig:
Returns: Returns:
AgentConfig instance AgentConfig instance
""" """
return cls( # Lazy import to avoid pulling provider_capabilities (and its
provider=config.provider.value # transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
config.provider.value
if hasattr(config.provider, "value") if hasattr(config.provider, "value")
else str(config.provider), else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
return cls(
provider=provider_value,
model_name=config.model_name, model_name=config.model_name,
api_key=config.api_key, api_key=config.api_key,
api_base=config.api_base, api_base=config.api_base,
@ -234,6 +243,16 @@ class AgentConfig:
is_premium=False, is_premium=False,
anonymous_enabled=False, anonymous_enabled=False,
quota_reserve_tokens=None, quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we
# ask LiteLLM (default-allow on unknown). The streaming
# safety net still blocks if the model is *explicitly*
# marked text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
),
) )
@classmethod @classmethod
@ -252,15 +271,46 @@ class AgentConfig:
Returns: Returns:
AgentConfig instance AgentConfig instance
""" """
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string # Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "") system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
# Explicit YAML override wins; otherwise derive from LiteLLM /
# OpenRouter modalities. The YAML loader already populates this
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
# so we re-derive here for safety. The bool() coercion preserves
# the loader's behaviour for explicit ``true`` / ``false``
# strings that PyYAML may surface.
if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
)
return cls( return cls(
provider=yaml_config.get("provider", "").upper(), provider=provider,
model_name=yaml_config.get("model_name", ""), model_name=model_name,
api_key=yaml_config.get("api_key", ""), api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"), api_base=yaml_config.get("api_base"),
custom_provider=yaml_config.get("custom_provider"), custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"), litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility) # Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None, system_instructions=system_instructions if system_instructions else None,
@ -275,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium", is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False), anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
supports_image_input=supports_image_input,
) )
@ -494,6 +545,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
llm = SanitizedChatLiteLLM(**litellm_kwargs) llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string) _attach_model_profile(llm, model_string)
# Configure LiteLLM-native prompt caching (cache_control_injection_points
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
# ``agent_config=None`` here — the YAML path doesn't have provider intent
# in a structured form, so we set only the universal injection points.
apply_litellm_prompt_caching(llm)
return llm return llm
@ -518,7 +574,16 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized") print("Error: Auto mode requested but LLM Router not initialized")
return None return None
try: try:
return get_auto_mode_llm() router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal cache_control_injection_points only — auto-mode
# fans out across providers, so OpenAI-only kwargs (e.g.
# ``prompt_cache_key``) are left off here. ``drop_params``
# would strip them at the provider boundary anyway, but
# there's no point setting them when we don't know the
# destination.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e: except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}") print(f"Error creating ChatLiteLLMRouter: {e}")
return None return None
@ -549,4 +614,9 @@ def create_chat_litellm_from_agent_config(
llm = SanitizedChatLiteLLM(**litellm_kwargs) llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string) _attach_model_profile(llm, model_string)
# Build-time prompt caching: sets ``cache_control_injection_points`` for
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
# Per-thread ``prompt_cache_key`` is layered on later in
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm return llm

View file

@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import ( from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware, SurfSenseFilesystemMiddleware,
) )
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
)
from app.agents.new_chat.middleware.kb_persistence import ( from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware, KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state, commit_staged_filesystem_state,
@ -61,6 +64,7 @@ __all__ = [
"DedupHITLToolCallsMiddleware", "DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware", "DoomLoopMiddleware",
"FileIntentMiddleware", "FileIntentMiddleware",
"FlattenSystemMessageMiddleware",
"KnowledgeBasePersistenceMiddleware", "KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware", "KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware", "KnowledgePriorityMiddleware",

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time
import weakref import weakref
from typing import Any from typing import Any
@ -58,6 +59,11 @@ class _ThreadLockManager:
weakref.WeakValueDictionary() weakref.WeakValueDictionary()
) )
self._cancel_events: dict[str, asyncio.Event] = {} self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
# Monotonic per-thread epoch used to prevent stale middleware
# teardown from releasing a newer turn's lock.
self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock: def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id) lock = self._locks.get(thread_id)
@ -76,14 +82,57 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool: def request_cancel(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id) event = self._cancel_events.get(thread_id)
if event is None: if event is None:
return False event = asyncio.Event()
self._cancel_events[thread_id] = event
event.set() event.set()
now_ms = int(time.time() * 1000)
self._cancel_requested_at_ms[thread_id] = now_ms
self._cancel_attempt_count[thread_id] = (
self._cancel_attempt_count.get(thread_id, 0) + 1
)
return True return True
def is_cancel_requested(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
return bool(event and event.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
if not self.is_cancel_requested(thread_id):
return None
attempts = self._cancel_attempt_count.get(thread_id, 1)
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
return attempts, requested_at_ms
def reset(self, thread_id: str) -> None: def reset(self, thread_id: str) -> None:
event = self._cancel_events.get(thread_id) event = self._cancel_events.get(thread_id)
if event is not None: if event is not None:
event.clear() event.clear()
self._cancel_requested_at_ms.pop(thread_id, None)
self._cancel_attempt_count.pop(thread_id, None)
def bump_turn_epoch(self, thread_id: str) -> int:
epoch = self._turn_epoch.get(thread_id, 0) + 1
self._turn_epoch[thread_id] = epoch
return epoch
def current_turn_epoch(self, thread_id: str) -> int:
return self._turn_epoch.get(thread_id, 0)
def end_turn(self, thread_id: str) -> None:
"""Best-effort terminal cleanup for a thread turn.
This is intentionally idempotent and safe to call from outer stream
finally-blocks where middleware teardown might be skipped due to abort
or disconnect edge-cases.
"""
# Invalidate any in-flight middleware holder first. This guarantees a
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
# retry that already acquired the lock for the same thread.
self.bump_turn_epoch(thread_id)
lock = self._locks.get(thread_id)
if lock is not None and lock.locked():
lock.release()
self.reset(thread_id)
def release(self, thread_id: str) -> bool: def release(self, thread_id: str) -> bool:
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``. """Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool: def request_cancel(thread_id: str) -> bool:
"""Trip the cancel event for ``thread_id``. Returns True if found.""" """Trip the cancel event for ``thread_id``. Always returns True."""
return manager.request_cancel(thread_id) return manager.request_cancel(thread_id)
def is_cancel_requested(thread_id: str) -> bool:
"""Return whether ``thread_id`` currently has a pending cancel signal."""
return manager.is_cancel_requested(thread_id)
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
return manager.cancel_state(thread_id)
def reset_cancel(thread_id: str) -> None: def reset_cancel(thread_id: str) -> None:
"""Reset the cancel event for ``thread_id`` (called between turns).""" """Reset the cancel event for ``thread_id`` (called between turns)."""
manager.reset(thread_id) manager.reset(thread_id)
def release_lock(thread_id: str) -> bool: def end_turn(thread_id: str) -> None:
"""Force-release the per-thread busy lock; safe to call when not held.""" """Force end-of-turn cleanup for lock + cancel state."""
return manager.release(thread_id) manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__() super().__init__()
self._require_thread_id = require_thread_id self._require_thread_id = require_thread_id
self.tools = [] self.tools = []
# Per-call locks owned by this middleware. We track them as # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
# an instance attribute so ``aafter_agent`` knows which lock # only releases when its epoch still matches the manager's current
# to release. # epoch for the thread, preventing stale unlock races.
self._held_locks: dict[str, asyncio.Lock] = {} self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod @staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None: def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked(): if lock.locked():
raise BusyError(request_id=thread_id) raise BusyError(request_id=thread_id)
await lock.acquire() await lock.acquire()
self._held_locks[thread_id] = lock epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh # Reset the cancel event so this turn starts fresh
reset_cancel(thread_id) reset_cancel(thread_id)
return None return None
@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime) thread_id = self._thread_id(runtime)
if thread_id is None: if thread_id is None:
return None return None
lock = self._held_locks.pop(thread_id, None) held = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked(): if held is None:
return None
lock, held_epoch = held
if held_epoch != manager.current_turn_epoch(thread_id):
# Stale teardown from an older attempt (e.g. runtime-recovery path
# already advanced epoch). Do not touch current lock/cancel state.
return None
if lock.locked():
lock.release() lock.release()
# Always clear cancel event between turns so a stale signal # Always clear cancel event between turns so a stale signal
# doesn't leak into the next request. # doesn't leak into the next request.
@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [ __all__ = [
"BusyMutexMiddleware", "BusyMutexMiddleware",
"end_turn",
"get_cancel_event", "get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager", "manager",
"release_lock",
"request_cancel", "request_cancel",
"reset_cancel", "reset_cancel",
] ]

View file

@ -0,0 +1,233 @@
r"""Coalesce multi-block system messages into a single text block.
Several middlewares in our deepagent stack each call
``append_to_system_message`` on the way down to the model
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
``SkillsMiddleware``, ``SubAgentMiddleware`` ). By the time the
request reaches the LLM, the system message has 5+ separate text blocks.
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
request**, and we configure 2 injection points
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
the prepended ``request.system_message``, this middleware is the
defensive partner: it guarantees that "the system block" is *one*
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
OpenRouterAnthropic transformer can never multiply our budget into
several breakpoints by spreading ``cache_control`` across multiple
text blocks of a multi-block system content.
Without flattening we used to see::
OpenrouterException - {"error":{"message":"Provider returned error",
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
cache_control may be provided. Found 5."}}}
(Same error class documented in
https://github.com/BerriAI/litellm/issues/15696 and
https://github.com/BerriAI/litellm/issues/20485 the litellm-side fix
in PR #15395 covers the litellm transformer but does not protect us
when the OpenRouter SaaS itself does the redistribution.)
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
the first injection point from ``role: system`` to ``index: 0``)
neutralises the *primary* cause of the same 400 multiple
``SystemMessage``\ s injected by ``before_agent`` middlewares
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
turns, each tagged with ``cache_control`` by the ``role: system``
matcher. This middleware remains useful as defence-in-depth against
the multi-block redistribution path.
Placement: innermost on the system-message-mutation chain, after every
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
summarization, but before ``noop``/``retry``/``fallback`` so each retry
attempt sees a flattened payload. See ``chat_deepagent.py``.
Idempotent: a string-content system message is left untouched. A list
that contains anything other than plain text blocks (e.g. an image) is
also left untouched those are rare on system messages and we'd lose
the non-text payload by joining.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import SystemMessage
logger = logging.getLogger(__name__)
def _flatten_text_blocks(content: list[Any]) -> str | None:
"""Return joined text if every block is a plain ``{"type": "text"}``.
Returns ``None`` when the list contains anything that isn't a text
block we can safely concatenate (image, audio, file, non-standard
blocks, dicts with extra non-cache_control fields). The caller
leaves the original content untouched in that case rather than
silently dropping payload.
``cache_control`` on individual blocks is intentionally discarded
the whole point of flattening is to let LiteLLM's
``cache_control_injection_points`` re-place a single breakpoint on
the resulting one-block system content.
"""
chunks: list[str] = []
for block in content:
if isinstance(block, str):
chunks.append(block)
continue
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
chunks.append(text)
return "\n\n".join(chunks)
def _flattened_request(
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
"""Return a request with system_message flattened, or ``None`` for no-op."""
sys_msg = request.system_message
if sys_msg is None:
return None
content = sys_msg.content
if not isinstance(content, list) or len(content) <= 1:
return None
flattened = _flatten_text_blocks(content)
if flattened is None:
return None
new_sys = SystemMessage(
content=flattened,
additional_kwargs=dict(sys_msg.additional_kwargs),
response_metadata=dict(sys_msg.response_metadata),
)
if sys_msg.id is not None:
new_sys.id = sys_msg.id
return request.override(system_message=new_sys)
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
"""One-line dump of cache_control-relevant request shape.
Temporary diagnostic to prove where the ``Found N`` cache_control
breakpoints are coming from when Anthropic 400s. Removed once the
root cause is confirmed and a fix is in place.
"""
sys_msg = request.system_message
if sys_msg is None:
sys_shape = "none"
elif isinstance(sys_msg.content, str):
sys_shape = f"str(len={len(sys_msg.content)})"
elif isinstance(sys_msg.content, list):
sys_shape = f"list(blocks={len(sys_msg.content)})"
else:
sys_shape = f"other({type(sys_msg.content).__name__})"
role_hist: list[str] = []
multi_block_msgs = 0
msgs_with_cc = 0
sys_msgs_in_history = 0
for m in request.messages:
mtype = getattr(m, "type", type(m).__name__)
role_hist.append(mtype)
if isinstance(m, SystemMessage):
sys_msgs_in_history += 1
c = getattr(m, "content", None)
if isinstance(c, list):
multi_block_msgs += 1
for blk in c:
if isinstance(blk, dict) and "cache_control" in blk:
msgs_with_cc += 1
break
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
msgs_with_cc += 1
tools = request.tools or []
tools_with_cc = 0
for t in tools:
if isinstance(t, dict) and (
"cache_control" in t or "cache_control" in t.get("function", {})
):
tools_with_cc += 1
return (
f"sys={sys_shape} msgs={len(request.messages)} "
f"sys_msgs_in_history={sys_msgs_in_history} "
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
f"roles={role_hist[-8:]}"
)
class FlattenSystemMessageMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Collapse a multi-text-block system message to a single string.
Sits innermost on the system-message-mutation chain so it observes
every middleware's contribution. Has no other side effect — the
body of every block is preserved, just joined with ``"\\n\\n"``.
"""
def __init__(self) -> None:
super().__init__()
self.tools = []
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return handler(flattened)
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
flattened = _flattened_request(request)
if flattened is not None:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"[flatten_system] collapsed %d system blocks to one",
len(request.system_message.content), # type: ignore[arg-type, union-attr]
)
return await handler(flattened)
return await handler(request)
__all__ = [
"FlattenSystemMessageMiddleware",
"_flatten_text_blocks",
"_flattened_request",
]

View file

@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState, state: AgentState,
runtime: Runtime[Any], runtime: Runtime[Any],
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD: if self.filesystem_mode != FilesystemMode.CLOUD:
return None return None
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if anon_doc: if anon_doc:
return self._anon_priority(state, anon_doc) return self._anon_priority(state, anon_doc)
return await self._authenticated_priority(state, messages, user_text) return await self._authenticated_priority(state, messages, user_text, runtime)
def _anon_priority( def _anon_priority(
self, self,
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
state: AgentState, state: AgentState,
messages: Sequence[BaseMessage], messages: Sequence[BaseMessage],
user_text: str, user_text: str,
runtime: Runtime[Any] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time() t0 = asyncio.get_event_loop().time()
( (
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text, user_text=user_text,
) )
# Per-turn ``mentioned_document_ids`` flow:
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
# on every ``astream_events`` call, so this list is naturally
# scoped to the current turn. Allows cross-turn graph reuse via
# ``agent_cache``.
# 2. Legacy fallback (cache disabled / context not propagated): the
# constructor-injected ``self.mentioned_document_ids`` list. We
# drain it after the first read so a cached graph (no Phase 1.5
# wiring) doesn't keep replaying the same mentions on every
# turn.
#
# CRITICAL: distinguish "context absent" (legacy caller, no field at
# all) from "context provided but empty" (turn with no mentions).
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
# Python, so a naive ``if ctx_mentions:`` would fall through to the
# legacy closure on every no-mention follow-up turn — replaying the
# mentions baked in by turn 1's cache-miss build. Always drain the
# closure once the runtime path has fired so a cached middleware
# instance can never resurrect stale state.
mention_ids: list[int] = []
ctx = getattr(runtime, "context", None) if runtime is not None else None
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
if ctx_mentions is not None:
# Runtime path is authoritative — even an empty list means
# "this turn has no mentions", NOT "look at the closure".
mention_ids = list(ctx_mentions)
if self.mentioned_document_ids:
self.mentioned_document_ids = []
elif self.mentioned_document_ids:
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
mentioned_results: list[dict[str, Any]] = [] mentioned_results: list[dict[str, Any]] = []
if self.mentioned_document_ids: if mention_ids:
mentioned_results = await fetch_mentioned_documents( mentioned_results = await fetch_mentioned_documents(
document_ids=self.mentioned_document_ids, document_ids=mention_ids,
search_space_id=self.search_space_id, search_space_id=self.search_space_id,
) )
self.mentioned_document_ids = []
if is_recency: if is_recency:
doc_types = _resolve_search_types( doc_types = _resolve_search_types(

View file

@ -0,0 +1,188 @@
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack its ``isinstance(model, ChatAnthropic)``
gate always failed) with LiteLLM's universal caching mechanism.
Coverage:
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
performs automatically when ``cache_control_injection_points`` is set):
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
``deepseek/``, ``xai/`` these caches automatically for prompts 1024
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
We inject **two** breakpoints per request:
- ``index: 0`` pins the SurfSense system prompt at the head of the
request (provider variant, citation rules, tool catalog, KB tree,
skills metadata). The langchain agent factory always prepends
``request.system_message`` at index 0 (see ``factory.py``
``_execute_model_async``), so this targets exactly the main system
prompt regardless of how many other ``SystemMessage``\ s the
``before_agent`` injectors (priority, tree, memory, file-intent,
anonymous-doc) have inserted into ``state["messages"]``. Using
``role: system`` here would apply ``cache_control`` to **every**
system-role message and trip Anthropic's hard cap of 4 cache
breakpoints per request once the conversation accumulates enough
injected system messages which surfaces as the upstream 400
``A maximum of 4 blocks with cache_control may be provided. Found N``
via OpenRouterAnthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend.
- ``prompt_cache_retention="24h"`` extends cache TTL beyond the default
5-10 min in-memory cache.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAIBedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from langchain_core.language_models import BaseChatModel
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
logger = logging.getLogger(__name__)
# Two-breakpoint policy: head-of-request + latest message. See module
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
# anonymous-doc) insert ``SystemMessage`` instances into
# ``state["messages"]`` that accumulate across turns. With
# ``role: system`` the LiteLLM hook would tag *every* one of them with
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
# always targets the langchain-prepended ``request.system_message``
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
# block), giving us exactly one stable cache breakpoint.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "index": 0},
{"location": "message", "index": -1},
)
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
# MINIMAX), so we can't infer family from the litellm prefix alone.
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
def _is_router_llm(llm: BaseChatModel) -> bool:
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
Importing ``app.services.llm_router_service`` at module-load time would
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
Class-name comparison is sufficient since the class is defined in a
single place.
"""
return type(llm).__name__ == "ChatLiteLLMRouter"
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
"""Whether the config targets an OpenAI-style prompt-cache surface.
Strict only returns True when the user explicitly chose OPENAI,
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
``YAMLConfig``. Auto-mode and custom providers return False because
we can't statically know the destination.
"""
if agent_config is None or not agent_config.provider:
return False
if agent_config.is_auto_mode:
return False
if agent_config.custom_provider:
return False
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
model. Returns ``None`` if the LLM type doesn't expose a writable
``model_kwargs`` attribute (caller should treat as no-op).
"""
model_kwargs = getattr(llm, "model_kwargs", None)
if isinstance(model_kwargs, dict):
return model_kwargs
try:
llm.model_kwargs = {} # type: ignore[attr-defined]
except Exception:
return None
refreshed = getattr(llm, "model_kwargs", None)
return refreshed if isinstance(refreshed, dict) else None
def apply_litellm_prompt_caching(
llm: BaseChatModel,
*,
agent_config: AgentConfig | None = None,
thread_id: int | None = None,
) -> None:
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
Idempotent values already present in ``llm.model_kwargs`` (e.g. from
``agent_config.litellm_params`` overrides) are preserved. Mutates
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
in our custom ``ChatLiteLLMRouter``.
Args:
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
agent_config: Optional ``AgentConfig`` driving provider-specific
behaviour. When omitted (or auto-mode), only the universal
``cache_control_injection_points`` are set.
thread_id: Optional thread id used to construct a per-thread
``prompt_cache_key`` for OpenAI-family providers. Caching still
works without it (server-side automatic), but the key improves
backend routing affinity and therefore hit rate.
"""
model_kwargs = _get_or_init_model_kwargs(llm)
if model_kwargs is None:
logger.debug(
"apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
type(llm).__name__,
)
return
if "cache_control_injection_points" not in model_kwargs:
model_kwargs["cache_control_injection_points"] = [
dict(point) for point in _DEFAULT_INJECTION_POINTS
]
# OpenAI-family extras only when we statically know the destination is
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
# so we can't safely set OpenAI-only kwargs there (drop_params would
# strip them but it's wasteful to set them in the first place).
if _is_router_llm(llm):
return
if not _is_openai_family_config(agent_config):
return
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
if "prompt_cache_retention" not in model_kwargs:
model_kwargs["prompt_cache_retention"] = "24h"

View file

@ -27,14 +27,9 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset( NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset(
{ {
# Exact tool names from shared deny patterns. # Exact tool names from shared deny patterns.
*{ *{name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name},
name
for name in WRITE_TOOL_DENY_PATTERNS
if "*" not in name
},
# Additional non-provider state mutation controls. # Additional non-provider state mutation controls.
"write_todos", "write_todos",
"task", "task",
} }
) )

View file

@ -112,10 +112,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
Rule(permission=name, pattern="*", action="deny") Rule(permission=name, pattern="*", action="deny")
for name in NON_PROVIDER_STATE_MUTATION_DENY for name in NON_PROVIDER_STATE_MUTATION_DENY
) )
rules.extend( rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
Rule(permission=name, pattern="*", action="ask")
for name in ask_tools
)
return PermissionMiddleware( return PermissionMiddleware(
rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")] rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")]
) )
@ -163,4 +160,3 @@ def build_linear_specialist_subagent(
if model is not None: if model is not None:
spec["model"] = model spec["model"] = model
return spec # type: ignore[return-value] return spec # type: ignore[return-value]

View file

@ -119,10 +119,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
Rule(permission=name, pattern="*", action="deny") Rule(permission=name, pattern="*", action="deny")
for name in NON_PROVIDER_STATE_MUTATION_DENY for name in NON_PROVIDER_STATE_MUTATION_DENY
) )
rules.extend( rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
Rule(permission=name, pattern="*", action="ask")
for name in ask_tools
)
return PermissionMiddleware( return PermissionMiddleware(
rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")] rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")]
) )
@ -171,4 +168,3 @@ def build_slack_specialist_subagent(
if model is not None: if model is not None:
spec["model"] = model spec["model"] = model
return spec # type: ignore[return-value] return spec # type: ignore[return-value]

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the create_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_confluence_page( async def create_confluence_page(
title: str, title: str,
@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
""" """
logger.info(f"create_confluence_page called: title='{title}'") logger.info(f"create_confluence_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
metadata_service = ConfluenceToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = ConfluenceToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
)
if "error" in context: if "error" in context:
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", []) accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts): if accounts and all(a.get("auth_expired") for a in accounts):
return { return {
"status": "auth_error", "status": "auth_error",
"message": "All connected Confluence accounts need re-authentication.", "message": "All connected Confluence accounts need re-authentication.",
"connector_type": "confluence", "connector_type": "confluence",
} }
result = request_approval( result = request_approval(
action_type="confluence_page_creation", action_type="confluence_page_creation",
tool_name="create_confluence_page", tool_name="create_confluence_page",
params={ params={
"title": title, "title": title,
"content": content, "content": content,
"space_id": space_id, "space_id": space_id,
"connector_id": connector_id, "connector_id": connector_id,
}, },
context=context, context=context,
) )
if result.rejected: if result.rejected:
return { return {
"status": "rejected", "status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.", "message": "User declined. Do not retry or suggest alternatives.",
} }
final_title = result.params.get("title", title) final_title = result.params.get("title", title)
final_content = result.params.get("content", content) or "" final_content = result.params.get("content", content) or ""
final_space_id = result.params.get("space_id", space_id) final_space_id = result.params.get("space_id", space_id)
final_connector_id = result.params.get("connector_id", connector_id) final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip(): if not final_title or not final_title.strip():
return {"status": "error", "message": "Page title cannot be empty."} return {"status": "error", "message": "Page title cannot be empty."}
if not final_space_id: if not final_space_id:
return {"status": "error", "message": "A space must be selected."} return {"status": "error", "message": "A space must be selected."}
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id actual_connector_id = final_connector_id
if actual_connector_id is None: if actual_connector_id is None:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR, == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
) )
) connector = result.scalars().first()
connector = result.scalars().first() if not connector:
if not connector: return {
return { "status": "error",
"status": "error", "message": "No Confluence connector found.",
"message": "No Confluence connector found.", }
} actual_connector_id = connector.id
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
api_result = await client.create_page(
space_id=final_space_id,
title=final_title,
body=final_content,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_id = str(api_result.get("id", ""))
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=page_id,
page_title=final_title,
space_id=final_space_id,
body_content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else: else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." result = await db_session.execute(
except Exception as kb_err: select(SearchSourceConnector).filter(
logger.warning(f"KB sync after create failed: {kb_err}") SearchSourceConnector.id == actual_connector_id,
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
return { try:
"status": "success", client = ConfluenceHistoryConnector(
"page_id": page_id, session=db_session, connector_id=actual_connector_id
"page_url": page_url, )
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}", api_result = await client.create_page(
} space_id=final_space_id,
title=final_title,
body=final_content,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_id = str(api_result.get("id", ""))
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=page_id,
page_title=final_title,
space_id=final_space_id,
body_content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"page_id": page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the delete_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_confluence_page( async def delete_confluence_page(
page_title_or_id: str, page_title_or_id: str,
@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'" f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
metadata_service = ConfluenceToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_deletion_context( metadata_service = ConfluenceToolMetadataService(db_session)
search_space_id, user_id, page_title_or_id context = await metadata_service.get_deletion_context(
) search_space_id, user_id, page_title_or_id
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
page_title = page_data.get("page_title", "")
document_id = page_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_deletion",
tool_name="delete_confluence_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
) )
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try: if "error" in context:
client = ConfluenceHistoryConnector( error_msg = context["error"]
session=db_session, connector_id=final_connector_id if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
page_title = page_data.get("page_title", "")
document_id = page_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_deletion",
tool_name="delete_confluence_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
) )
await client.delete_page(final_page_id)
await client.close() if result.rejected:
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return { return {
"status": "insufficient_permissions", "status": "rejected",
"connector_id": final_connector_id, "message": "User declined. Do not retry or suggest alternatives.",
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
deleted_from_kb = False final_page_id = result.params.get("page_id", page_id)
if final_delete_from_kb and document_id: final_connector_id = result.params.get(
try: "connector_id", connector_id_from_context
from app.db import Document )
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
doc_result = await db_session.execute( from sqlalchemy.future import select
select(Document).filter(Document.id == document_id)
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
) )
document = doc_result.scalars().first() )
if document: connector = result.scalars().first()
await db_session.delete(document) if not connector:
await db_session.commit() return {
deleted_from_kb = True "status": "error",
except Exception as e: "message": "Selected Confluence connector is invalid.",
logger.error(f"Failed to delete document from KB: {e}") }
await db_session.rollback()
message = f"Confluence page '{page_title}' deleted successfully." try:
if deleted_from_kb: client = ConfluenceHistoryConnector(
message += " Also removed from the knowledge base." session=db_session, connector_id=final_connector_id
)
await client.delete_page(final_page_id)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
return { deleted_from_kb = False
"status": "success", if final_delete_from_kb and document_id:
"page_id": final_page_id, try:
"deleted_from_kb": deleted_from_kb, from app.db import Document
"message": message,
} doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Confluence page '{page_title}' deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"page_id": final_page_id,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""
Factory function to create the update_confluence_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_confluence_page tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_confluence_page( async def update_confluence_page(
page_title_or_id: str, page_title_or_id: str,
@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'" f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Confluence tool not properly configured.", "message": "Confluence tool not properly configured.",
} }
try: try:
metadata_service = ConfluenceToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = ConfluenceToolMetadataService(db_session)
search_space_id, user_id, page_title_or_id context = await metadata_service.get_update_context(
) search_space_id, user_id, page_title_or_id
)
if "error" in context: if "error" in context:
error_msg = context["error"] error_msg = context["error"]
if context.get("auth_expired"): if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
current_title = page_data["page_title"]
current_body = page_data.get("body", "")
current_version = page_data.get("version", 1)
document_id = page_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="confluence_page_update",
tool_name="update_confluence_page",
params={
"page_id": page_id,
"document_id": document_id,
"new_title": new_title,
"new_content": new_content,
"version": current_version,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return { return {
"status": "auth_error", "status": "rejected",
"message": error_msg, "message": "User declined. Do not retry or suggest alternatives.",
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
} }
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"] final_page_id = result.params.get("page_id", page_id)
page_id = page_data["page_id"] final_title = result.params.get("new_title", new_title) or current_title
current_title = page_data["page_title"] final_content = result.params.get("new_content", new_content)
current_body = page_data.get("body", "") if final_content is None:
current_version = page_data.get("version", 1) final_content = current_body
document_id = page_data.get("document_id") final_version = result.params.get("version", current_version)
connector_id_from_context = context.get("account", {}).get("id") final_connector_id = result.params.get(
"connector_id", connector_id_from_context
result = request_approval(
action_type="confluence_page_update",
tool_name="update_confluence_page",
params={
"page_id": page_id,
"document_id": document_id,
"new_title": new_title,
"new_content": new_content,
"version": current_version,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_title = result.params.get("new_title", new_title) or current_title
final_content = result.params.get("new_content", new_content)
if final_content is None:
final_content = current_body
final_version = result.params.get("version", current_version)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
) )
) final_document_id = result.params.get("document_id", document_id)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try: from sqlalchemy.future import select
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id from app.db import SearchSourceConnector, SearchSourceConnectorType
)
api_result = await client.update_page( if not final_connector_id:
page_id=final_page_id,
title=final_title,
body=final_content,
version_number=final_version + 1,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return { return {
"status": "insufficient_permissions", "status": "error",
"connector_id": final_connector_id, "message": "No connector found for this page.",
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
page_links = ( result = await db_session.execute(
api_result.get("_links", {}) if isinstance(api_result, dict) else {} select(SearchSourceConnector).filter(
) SearchSourceConnector.id == final_connector_id,
page_url = "" SearchSourceConnector.search_space_id == search_space_id,
if page_links.get("base") and page_links.get("webui"): SearchSourceConnector.user_id == user_id,
page_url = f"{page_links['base']}{page_links['webui']}" SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
kb_message_suffix = ""
if final_document_id:
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
page_id=final_page_id,
user_id=user_id,
search_space_id=search_space_id,
) )
if kb_result["status"] == "success": )
kb_message_suffix = ( connector = result.scalars().first()
" Your knowledge base has also been updated." if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
api_result = await client.update_page(
page_id=final_page_id,
title=final_title,
body=final_content,
version_number=final_version + 1,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
if final_document_id:
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
page_id=final_page_id,
user_id=user_id,
search_space_id=search_space_id,
) )
else: if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = ( kb_message_suffix = (
" The knowledge base will be updated in the next sync." " The knowledge base will be updated in the next sync."
) )
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return { return {
"status": "success", "status": "success",
"page_id": final_page_id, "page_id": final_page_id,
"page_url": page_url, "page_url": page_url,
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}", "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
} }
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
search_space_id: int, search_space_id: int,
user_id: str, user_id: str,
) -> StructuredTool: ) -> StructuredTool:
"""Factory function to create the get_connected_accounts tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to scope account discovery to.
user_id: User ID to scope account discovery to.
Returns:
Configured StructuredTool for connected-accounts discovery.
"""
del db_session # per-call session — see docstring
async def _run(service: str) -> list[dict[str, Any]]: async def _run(service: str) -> list[dict[str, Any]]:
svc_cfg = MCP_SERVICES.get(service) svc_cfg = MCP_SERVICES.get(service)
@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
except ValueError: except ValueError:
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter( result = await db_session.execute(
SearchSourceConnector.search_space_id == search_space_id, select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type == connector_type, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == connector_type,
)
) )
) connectors = result.scalars().all()
connectors = result.scalars().all()
if not connectors: if not connectors:
return [ return [
{ {
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
}
]
is_multi = len(connectors) > 1
accounts: list[dict[str, Any]] = []
for conn in connectors:
cfg = conn.config or {}
entry: dict[str, Any] = {
"connector_id": conn.id,
"display_name": _extract_display_name(conn),
"service": service,
} }
] if is_multi:
entry["tool_prefix"] = f"{service}_{conn.id}"
for key in svc_cfg.account_metadata_keys:
if key in cfg:
entry[key] = cfg[key]
accounts.append(entry)
is_multi = len(connectors) > 1 return accounts
accounts: list[dict[str, Any]] = []
for conn in connectors:
cfg = conn.config or {}
entry: dict[str, Any] = {
"connector_id": conn.id,
"display_name": _extract_display_name(conn),
"service": service,
}
if is_multi:
entry["tool_prefix"] = f"{service}_{conn.id}"
for key in svc_cfg.account_metadata_keys:
if key in cfg:
entry[key] = cfg[key]
accounts.append(entry)
return accounts
return StructuredTool( return StructuredTool(
name="get_connected_accounts", name="get_connected_accounts",

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_discord_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_discord_channels tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_discord_channels() -> dict[str, Any]: async def list_discord_channels() -> dict[str, Any]:
"""List text channels in the connected Discord server. """List text channels in the connected Discord server.
@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
Returns: Returns:
Dictionary with status and a list of channels (id, name). Dictionary with status and a list of channels (id, name).
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
} }
try: try:
connector = await get_discord_connector( async with async_session_maker() as db_session:
db_session, search_space_id, user_id connector = await get_discord_connector(
) db_session, search_space_id, user_id
if not connector:
return {"status": "error", "message": "No Discord connector found."}
guild_id = get_guild_id(connector)
if not guild_id:
return {
"status": "error",
"message": "No guild ID in Discord connector config.",
}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/guilds/{guild_id}/channels",
headers={"Authorization": f"Bot {token}"},
timeout=15.0,
) )
if not connector:
return {"status": "error", "message": "No Discord connector found."}
if resp.status_code == 401: guild_id = get_guild_id(connector)
return { if not guild_id:
"status": "auth_error", return {
"message": "Discord bot token is invalid.", "status": "error",
"connector_type": "discord", "message": "No guild ID in Discord connector config.",
} }
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
# Type 0 = text channel token = get_bot_token(connector)
channels = [
{"id": ch["id"], "name": ch["name"]} async with httpx.AsyncClient() as client:
for ch in resp.json() resp = await client.get(
if ch.get("type") == 0 f"{DISCORD_API}/guilds/{guild_id}/channels",
] headers={"Authorization": f"Bot {token}"},
return { timeout=15.0,
"status": "success", )
"guild_id": guild_id,
"channels": channels, if resp.status_code == 401:
"total": len(channels), return {
} "status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
# Type 0 = text channel
channels = [
{"id": ch["id"], "name": ch["name"]}
for ch in resp.json()
if ch.get("type") == 0
]
return {
"status": "success",
"guild_id": guild_id,
"channels": channels,
"total": len(channels),
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_discord_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_discord_messages tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_discord_messages( async def read_discord_messages(
channel_id: str, channel_id: str,
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
Dictionary with status and a list of messages including Dictionary with status and a list of messages including
id, author, content, timestamp. id, author, content, timestamp.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
limit = min(limit, 50) limit = min(limit, 50)
try: try:
connector = await get_discord_connector( async with async_session_maker() as db_session:
db_session, search_space_id, user_id connector = await get_discord_connector(
) db_session, search_space_id, user_id
if not connector:
return {"status": "error", "message": "No Discord connector found."}
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{DISCORD_API}/channels/{channel_id}/messages",
headers={"Authorization": f"Bot {token}"},
params={"limit": limit},
timeout=15.0,
) )
if not connector:
return {"status": "error", "message": "No Discord connector found."}
if resp.status_code == 401: token = get_bot_token(connector)
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
messages = [ async with httpx.AsyncClient() as client:
{ resp = await client.get(
"id": m["id"], f"{DISCORD_API}/channels/{channel_id}/messages",
"author": m.get("author", {}).get("username", "Unknown"), headers={"Authorization": f"Bot {token}"},
"content": m.get("content", ""), params={"limit": limit},
"timestamp": m.get("timestamp", ""), timeout=15.0,
} )
for m in resp.json()
]
return { if resp.status_code == 401:
"status": "success", return {
"channel_id": channel_id, "status": "auth_error",
"messages": messages, "message": "Discord bot token is invalid.",
"total": len(messages), "connector_type": "discord",
} }
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
messages = [
{
"id": m["id"],
"author": m.get("author", {}).get("username", "Unknown"),
"content": m.get("content", ""),
"timestamp": m.get("timestamp", ""),
}
for m in resp.json()
]
return {
"status": "success",
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector from ._auth import DISCORD_API, get_bot_token, get_discord_connector
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_discord_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_discord_message tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_discord_message( async def send_discord_message(
channel_id: str, channel_id: str,
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Discord tool not properly configured.", "message": "Discord tool not properly configured.",
@ -47,64 +65,65 @@ def create_send_discord_message_tool(
} }
try: try:
connector = await get_discord_connector( async with async_session_maker() as db_session:
db_session, search_space_id, user_id connector = await get_discord_connector(
) db_session, search_space_id, user_id
if not connector: )
return {"status": "error", "message": "No Discord connector found."} if not connector:
return {"status": "error", "message": "No Discord connector found."}
result = request_approval( result = request_approval(
action_type="discord_send_message", action_type="discord_send_message",
tool_name="send_discord_message", tool_name="send_discord_message",
params={"channel_id": channel_id, "content": content}, params={"channel_id": channel_id, "content": content},
context={"connector_id": connector.id}, context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
final_content = result.params.get("content", content)
final_channel = result.params.get("channel_id", channel_id)
token = get_bot_token(connector)
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{DISCORD_API}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
},
json={"content": final_content},
timeout=15.0,
) )
if resp.status_code == 401: if result.rejected:
return { return {
"status": "auth_error", "status": "rejected",
"message": "Discord bot token is invalid.", "message": "User declined. Message was not sent.",
"connector_type": "discord", }
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to send messages in this channel.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
msg_data = resp.json() final_content = result.params.get("content", content)
return { final_channel = result.params.get("channel_id", channel_id)
"status": "success",
"message_id": msg_data.get("id"), token = get_bot_token(connector)
"message": f"Message sent to channel {final_channel}.",
} async with httpx.AsyncClient() as client:
resp = await client.post(
f"{DISCORD_API}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
},
json={"content": final_content},
timeout=15.0,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Discord bot token is invalid.",
"connector_type": "discord",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Bot lacks permission to send messages in this channel.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Discord API error: {resp.status_code}",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": f"Message sent to channel {final_channel}.",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_dropbox_file( async def create_dropbox_file(
name: str, name: str,
@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
f"create_dropbox_file called: name='{name}', file_type='{file_type}'" f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Dropbox tool not properly configured.", "message": "Dropbox tool not properly configured.",
} }
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
connectors = result.scalars().all()
if not connectors:
return {
"status": "error",
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Dropbox accounts need re-authentication.",
"connector_type": "dropbox",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = DropboxClient(session=db_session, connector_id=cid)
items, err = await client.list_folder("")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{
"folder_path": item.get("path_lower", ""),
"name": item["name"],
}
for item in items
if item.get(".tag") == "folder" and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
"supported_types": _SUPPORTED_TYPES,
}
result = request_approval(
action_type="dropbox_file_creation",
tool_name="create_dropbox_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_path": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_path = result.params.get("parent_folder_path")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_extension(final_name, final_file_type)
if final_connector_id is not None:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR, == SearchSourceConnectorType.DROPBOX_CONNECTOR,
) )
) )
connector = result.scalars().first() connectors = result.scalars().all()
else:
connector = connectors[0]
if not connector: if not connectors:
return { return {
"status": "error", "status": "error",
"message": "Selected Dropbox connector is invalid.", "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Dropbox accounts need re-authentication.",
"connector_type": "dropbox",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = DropboxClient(session=db_session, connector_id=cid)
items, err = await client.list_folder("")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{
"folder_path": item.get("path_lower", ""),
"name": item["name"],
}
for item in items
if item.get(".tag") == "folder" and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s",
cid,
exc_info=True,
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
"supported_types": _SUPPORTED_TYPES,
} }
client = DropboxClient(session=db_session, connector_id=connector.id) result = request_approval(
action_type="dropbox_file_creation",
parent_path = final_parent_folder_path or "" tool_name="create_dropbox_file",
file_path = ( params={
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}" "name": name,
) "file_type": file_type,
"content": content,
if final_file_type == "paper": "connector_id": None,
created = await client.create_paper_doc(file_path, final_content or "") "parent_folder_path": None,
file_id = created.get("file_id", "") },
web_url = created.get("url", "") context=context,
else:
docx_bytes = _markdown_to_docx(final_content or "")
created = await client.upload_file(
file_path, docx_bytes, mode="add", autorename=True
) )
file_id = created.get("id", "")
web_url = ""
logger.info(f"Dropbox file created: id={file_id}, name={final_name}") if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
kb_message_suffix = "" final_name = result.params.get("name", name)
try: final_file_type = result.params.get("file_type", file_type)
from app.services.dropbox import DropboxKBSyncService final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_path = result.params.get("parent_folder_path")
kb_service = DropboxKBSyncService(db_session) if not final_name or not final_name.strip():
kb_result = await kb_service.sync_after_create( return {"status": "error", "message": "File name cannot be empty."}
file_id=file_id,
file_name=final_name, final_name = _ensure_extension(final_name, final_file_type)
file_path=file_path,
web_url=web_url, if final_connector_id is not None:
content=final_content, result = await db_session.execute(
connector_id=connector.id, select(SearchSourceConnector).filter(
search_space_id=search_space_id, SearchSourceConnector.id == final_connector_id,
user_id=user_id, SearchSourceConnector.search_space_id == search_space_id,
) SearchSourceConnector.user_id == user_id,
if kb_result["status"] == "success": SearchSourceConnector.connector_type
kb_message_suffix = " Your knowledge base has also been updated." == SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
connector = result.scalars().first()
else: else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync." connector = connectors[0]
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return { if not connector:
"status": "success", return {
"file_id": file_id, "status": "error",
"name": final_name, "message": "Selected Dropbox connector is invalid.",
"web_url": web_url, }
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
} client = DropboxClient(session=db_session, connector_id=connector.id)
parent_path = final_parent_folder_path or ""
file_path = (
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
)
if final_file_type == "paper":
created = await client.create_paper_doc(
file_path, final_content or ""
)
file_id = created.get("file_id", "")
web_url = created.get("url", "")
else:
docx_bytes = _markdown_to_docx(final_content or "")
created = await client.upload_file(
file_path, docx_bytes, mode="add", autorename=True
)
file_id = created.get("id", "")
web_url = ""
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
kb_message_suffix = ""
try:
from app.services.dropbox import DropboxKBSyncService
kb_service = DropboxKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=file_id,
file_name=final_name,
file_path=file_path,
web_url=web_url,
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": file_id,
"name": final_name,
"web_url": web_url,
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType, DocumentType,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
async_session_maker,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_dropbox_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_dropbox_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_dropbox_file( async def delete_dropbox_file(
file_name: str, file_name: str,
@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Dropbox tool not properly configured.", "message": "Dropbox tool not properly configured.",
} }
try: try:
doc_result = await db_session.execute( async with async_session_maker() as db_session:
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
document = doc_result.scalars().first()
if not document:
doc_result = await db_session.execute( doc_result = await db_session.execute(
select(Document) select(Document)
.join( .join(
@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
and_( and_(
Document.search_space_id == search_space_id, Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE, Document.document_type == DocumentType.DROPBOX_FILE,
func.lower( func.lower(Document.title) == func.lower(file_name),
cast(
Document.document_metadata["dropbox_file_name"],
String,
)
)
== func.lower(file_name),
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
) )
) )
@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
) )
document = doc_result.scalars().first() document = doc_result.scalars().first()
if not document: if not document:
return { doc_result = await db_session.execute(
"status": "not_found", select(Document)
"message": ( .join(
f"File '{file_name}' not found in your indexed Dropbox files. " SearchSourceConnector,
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " Document.connector_id == SearchSourceConnector.id,
"or (3) the file name is different." )
), .filter(
} and_(
Document.search_space_id == search_space_id,
if not document.connector_id: Document.document_type == DocumentType.DROPBOX_FILE,
return { func.lower(
"status": "error", cast(
"message": "Document has no associated connector.", Document.document_metadata["dropbox_file_name"],
} String,
)
meta = document.document_metadata or {} )
file_path = meta.get("dropbox_path") == func.lower(file_name),
file_id = meta.get("dropbox_file_id") SearchSourceConnector.user_id == user_id,
document_id = document.id )
)
if not file_path: .order_by(Document.updated_at.desc().nullslast())
return { .limit(1)
"status": "error",
"message": "File path is missing. Please re-index the file.",
}
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
) )
) document = doc_result.scalars().first()
)
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Dropbox connector not found or access denied.",
}
cfg = connector.config or {} if not document:
if cfg.get("auth_expired"): return {
return { "status": "not_found",
"status": "auth_error", "message": (
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.", f"File '{file_name}' not found in your indexed Dropbox files. "
"connector_type": "dropbox", "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
} "or (3) the file name is different."
),
}
context = { if not document.connector_id:
"file": { return {
"file_id": file_id, "status": "error",
"file_path": file_path, "message": "Document has no associated connector.",
"name": file_name, }
"document_id": document_id,
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
result = request_approval( meta = document.document_metadata or {}
action_type="dropbox_file_trash", file_path = meta.get("dropbox_path")
tool_name="delete_dropbox_file", file_id = meta.get("dropbox_file_id")
params={ document_id = document.id
"file_path": file_path,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected: if not file_path:
return { return {
"status": "rejected", "status": "error",
"message": "User declined. Do not retry or suggest alternatives.", "message": "File path is missing. Please re-index the file.",
} }
final_file_path = result.params.get("file_path", file_path) conn_result = await db_session.execute(
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
and_( and_(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
) )
) )
) )
validated_connector = result.scalars().first() connector = conn_result.scalars().first()
if not validated_connector: if not connector:
return { return {
"status": "error", "status": "error",
"message": "Selected Dropbox connector is invalid or has been disconnected.", "message": "Dropbox connector not found or access denied.",
} }
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info( cfg = connector.config or {}
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}" if cfg.get("auth_expired"):
) return {
"status": "auth_error",
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "dropbox",
}
client = DropboxClient(session=db_session, connector_id=actual_connector_id) context = {
await client.delete_file(final_file_path) "file": {
"file_id": file_id,
"file_path": file_path,
"name": file_name,
"document_id": document_id,
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
logger.info(f"Dropbox file deleted: path={final_file_path}") result = request_approval(
action_type="dropbox_file_trash",
trash_result: dict[str, Any] = { tool_name="delete_dropbox_file",
"status": "success", params={
"file_id": file_id, "file_path": file_path,
"message": f"Successfully deleted '{file_name}' from Dropbox.", "connector_id": connector.id,
} "delete_from_kb": delete_from_kb,
},
deleted_from_kb = False context=context,
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File deleted, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
) )
return trash_result if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_file_path = result.params.get("file_path", file_path)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
return {
"status": "error",
"message": "Selected Dropbox connector is invalid or has been disconnected.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
)
client = DropboxClient(
session=db_session, connector_id=actual_connector_id
)
await client.delete_file(final_file_path)
logger.info(f"Dropbox file deleted: path={final_file_path}")
trash_result: dict[str, Any] = {
"status": "success",
"file_id": file_id,
"message": f"Successfully deleted '{file_name}' from Dropbox.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File deleted, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService, ImageGenRouterService,
is_image_gen_auto_mode, is_image_gen_auto_mode,
) )
from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
} }
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string( def _build_model_string(
provider: str, model_name: str, custom_provider: str | None provider: str, model_name: str, custom_provider: str | None
) -> str: ) -> str:
if custom_provider: prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}" return f"{prefix}/{model_name}"
@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found" "error": f"Image generation config {config_id} not found"
} }
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("provider", ""), cfg.get("custom_provider")
cfg["model_name"],
cfg.get("custom_provider"),
) )
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key") gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"): api_base = resolve_api_base(
gen_kwargs["api_base"] = cfg["api_base"] provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"): if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"] gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"): if cfg.get("litellm_params"):
@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found" "error": f"Image generation config {config_id} not found"
} }
model_string = _build_model_string( provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.provider.value, db_cfg.custom_provider
db_cfg.model_name,
db_cfg.custom_provider,
) )
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base: api_base = resolve_api_base(
gen_kwargs["api_base"] = db_cfg.api_base provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version: if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params: if db_cfg.litellm_params:

View file

@ -0,0 +1,41 @@
from typing import Any
from app.db import SearchSourceConnector
from app.services.composio_service import ComposioService
def split_recipients(value: str | None) -> list[str]:
if not value:
return []
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
def unwrap_composio_data(data: Any) -> Any:
if isinstance(data, dict):
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
return data
async def execute_composio_gmail_tool(
connector: SearchSourceConnector,
user_id: str,
tool_name: str,
params: dict[str, Any],
) -> tuple[Any, str | None]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return None, "Composio connected account ID not found for this Gmail connector."
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name=tool_name,
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
return None, result.get("error", "Unknown Composio Gmail error")
return unwrap_composio_data(result.get("data")), None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_gmail_draft( async def create_gmail_draft(
to: str, to: str,
@ -57,246 +75,276 @@ def create_create_gmail_draft_tool(
""" """
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'") logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GmailToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = GmailToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_draft_creation",
tool_name="create_gmail_draft",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
return { logger.error(
"status": "error", f"Failed to fetch creation context: {context['error']}"
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
) )
) return {"status": "error", "message": context["error"]}
connector = result.scalars().first()
if not connector: accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return { return {
"status": "error", "status": "auth_error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
} }
actual_connector_id = connector.id
logger.info( logger.info(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
) )
result = request_approval(
action_type="gmail_draft_creation",
tool_name="create_gmail_draft",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if ( if result.rejected:
connector.connector_type return {
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR "status": "rejected",
): "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
from app.utils.google_credentials import build_composio_credentials }
cca_id = connector.config.get("composio_connected_account_id") final_to = result.params.get("to", to)
if cca_id: final_subject = result.params.get("subject", subject)
creds = build_composio_credentials(cca_id) final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else: else:
return { result = await db_session.execute(
"status": "error", select(SearchSourceConnector).filter(
"message": "Composio connected account ID not found for this Gmail connector.", SearchSourceConnector.search_space_id == search_space_id,
} SearchSourceConnector.user_id == user_id,
else: SearchSourceConnector.connector_type.in_(_gmail_types),
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
) )
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw}})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
) )
try: connector = result.scalars().first()
from sqlalchemy.orm.attributes import flag_modified if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
_res = await db_session.execute( logger.info(
select(SearchSourceConnector).where( f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
SearchSourceConnector.id == actual_connector_id )
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
) )
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
) )
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"): created, error = await execute_composio_gmail_tool(
_conn.config = {**_conn.config, "auth_expired": True} connector,
flag_modified(_conn, "config") user_id,
await db_session.commit() "GMAIL_CREATE_EMAIL_DRAFT",
except Exception: {
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(created, dict):
created = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw}})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning( logger.warning(
"Failed to persist auth_expired for connector %s", f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
actual_connector_id,
exc_info=True,
) )
return { try:
"status": "insufficient_permissions", from sqlalchemy.orm.attributes import flag_modified
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Gmail draft created: id={created.get('id')}") _res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
kb_message_suffix = "" logger.info(f"Gmail draft created: id={created.get('id')}")
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session) kb_message_suffix = ""
draft_message = created.get("message", {}) try:
kb_result = await kb_service.sync_after_create( from app.services.gmail import GmailKBSyncService
message_id=draft_message.get("id", ""),
thread_id=draft_message.get("threadId", ""), kb_service = GmailKBSyncService(db_session)
subject=final_subject, draft_message = created.get("message", {})
sender="me", kb_result = await kb_service.sync_after_create(
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), message_id=draft_message.get("id", ""),
body_text=final_body, thread_id=draft_message.get("threadId", ""),
connector_id=actual_connector_id, subject=final_subject,
search_space_id=search_space_id, sender="me",
user_id=user_id, date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
draft_id=created.get("id"), body_text=final_body,
) connector_id=actual_connector_id,
if kb_result["status"] == "success": search_space_id=search_space_id,
kb_message_suffix = " Your knowledge base has also been updated." user_id=user_id,
else: draft_id=created.get("id"),
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
return { return {
"status": "success", "status": "success",
"draft_id": created.get("id"), "draft_id": created.get("id"),
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}", "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
} }
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_gmail_email(message_id: str) -> dict[str, Any]: async def read_gmail_email(message_id: str) -> dict[str, Any]:
"""Read the full content of a specific Gmail email by its message ID. """Read the full content of a specific Gmail email by its message ID.
@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
Returns: Returns:
Dictionary with status and the full email content formatted as markdown. Dictionary with status and the full email content formatted as markdown.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."} return {"status": "error", "message": "Gmail tool not properly configured."}
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter( result = await db_session.execute(
SearchSourceConnector.search_space_id == search_space_id, select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
) )
) connector = result.scalars().first()
connector = result.scalars().first() if not connector:
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
detail, error = await gmail.get_message_details(message_id)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return { return {
"status": "auth_error", "status": "error",
"message": error, "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
"connector_type": "gmail",
} }
return {"status": "error", "message": error}
if not detail: if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.agents.new_chat.tools.gmail.search_emails import (
_format_gmail_summary,
)
from app.services.composio_service import ComposioService
service = ComposioService()
detail, error = await service.get_gmail_message_detail(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if error:
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
summary = _format_gmail_summary(detail)
content = (
f"# {summary['subject']}\n\n"
f"**From:** {summary['from']}\n"
f"**To:** {summary['to']}\n"
f"**Date:** {summary['date']}\n\n"
f"## Message Content\n\n"
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {summary['message_id']}\n"
f"- **Thread ID:** {summary['thread_id']}\n"
)
return {
"status": "success",
"message_id": summary["message_id"] or message_id,
"content": content,
}
from app.agents.new_chat.tools.gmail.search_emails import (
_build_credentials,
)
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
detail, error = await gmail.get_message_details(message_id)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
}
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
content = gmail.format_message_to_markdown(detail)
return { return {
"status": "not_found", "status": "success",
"message": f"Email with ID '{message_id}' not found.", "message_id": message_id,
"content": content,
} }
content = gmail.format_message_to_markdown(detail)
return {"status": "success", "message_id": message_id, "content": content}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
from app.utils.google_credentials import build_composio_credentials raise ValueError("Composio connectors must use Composio tool execution.")
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected account ID not found.")
return build_composio_credentials(cca_id)
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
) )
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
headers = message.get("payload", {}).get("headers", [])
return {
header.get("name", "").lower(): header.get("value", "")
for header in headers
if isinstance(header, dict)
}
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
headers = _gmail_headers(message)
return {
"message_id": message.get("id") or message.get("messageId"),
"thread_id": message.get("threadId"),
"subject": message.get("subject") or headers.get("subject", "No Subject"),
"from": message.get("sender") or headers.get("from", "Unknown"),
"to": message.get("to") or headers.get("to", ""),
"date": message.get("messageTimestamp") or headers.get("date", ""),
"snippet": message.get("snippet") or message.get("messageText", "")[:300],
"labels": message.get("labelIds", []),
}
async def _search_composio_gmail(
connector: SearchSourceConnector,
user_id: str,
query: str,
max_results: int,
) -> dict[str, Any]:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found.",
}
from app.services.composio_service import ComposioService
service = ComposioService()
messages, _next_token, _estimate, error = await service.get_gmail_messages(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=max_results,
)
if error:
return {"status": "error", "message": error}
emails = [_format_gmail_summary(message) for message in messages]
return {
"status": "success",
"emails": emails,
"total": len(emails),
"message": "No emails found." if not emails else None,
}
def create_search_gmail_tool( def create_search_gmail_tool(
db_session: AsyncSession | None = None, db_session: AsyncSession | None = None,
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the search_gmail tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_gmail tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def search_gmail( async def search_gmail(
query: str, query: str,
@ -90,83 +159,92 @@ def create_search_gmail_tool(
Dictionary with status and a list of email summaries including Dictionary with status and a list of email summaries including
message_id, subject, from, date, snippet. message_id, subject, from, date, snippet.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Gmail tool not properly configured."} return {"status": "error", "message": "Gmail tool not properly configured."}
max_results = min(max_results, 20) max_results = min(max_results, 20)
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter( result = await db_session.execute(
SearchSourceConnector.search_space_id == search_space_id, select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
)
) )
) connector = result.scalars().first()
connector = result.scalars().first() if not connector:
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector
gmail = GoogleGmailConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
messages_list, error = await gmail.get_messages_list(
max_results=max_results, query=query
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return { return {
"status": "auth_error", "status": "error",
"message": error, "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
"connector_type": "gmail",
} }
return {"status": "error", "message": error}
if not messages_list: if (
return { connector.connector_type
"status": "success", == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
"emails": [], ):
"total": 0, return await _search_composio_gmail(
"message": "No emails found.", connector, str(user_id), query, max_results
} )
emails = [] creds = _build_credentials(connector)
for msg in messages_list:
detail, err = await gmail.get_message_details(msg["id"]) from app.connectors.google_gmail_connector import GoogleGmailConnector
if err:
continue gmail = GoogleGmailConnector(
headers = { credentials=creds,
h["name"].lower(): h["value"] session=db_session,
for h in detail.get("payload", {}).get("headers", []) user_id=user_id,
} connector_id=connector.id,
emails.append(
{
"message_id": detail.get("id"),
"thread_id": detail.get("threadId"),
"subject": headers.get("subject", "No Subject"),
"from": headers.get("from", "Unknown"),
"to": headers.get("to", ""),
"date": headers.get("date", ""),
"snippet": detail.get("snippet", ""),
"labels": detail.get("labelIds", []),
}
) )
return {"status": "success", "emails": emails, "total": len(emails)} messages_list, error = await gmail.get_messages_list(
max_results=max_results, query=query
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "gmail",
}
return {"status": "error", "message": error}
if not messages_list:
return {
"status": "success",
"emails": [],
"total": 0,
"message": "No emails found.",
}
emails = []
for msg in messages_list:
detail, err = await gmail.get_message_details(msg["id"])
if err:
continue
headers = {
h["name"].lower(): h["value"]
for h in detail.get("payload", {}).get("headers", [])
}
emails.append(
{
"message_id": detail.get("id"),
"thread_id": detail.get("threadId"),
"subject": headers.get("subject", "No Subject"),
"from": headers.get("from", "Unknown"),
"to": headers.get("to", ""),
"date": headers.get("date", ""),
"snippet": detail.get("snippet", ""),
"labels": detail.get("labelIds", []),
}
)
return {"status": "success", "emails": emails, "total": len(emails)}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_gmail_email( async def send_gmail_email(
to: str, to: str,
@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
""" """
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GmailToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = GmailToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
)
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", to)
final_subject = result.params.get("subject", subject)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
return { logger.error(
"status": "error", f"Failed to fetch creation context: {context['error']}"
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
) )
) return {"status": "error", "message": context["error"]}
connector = result.scalars().first()
if not connector: accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return { return {
"status": "error", "status": "auth_error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
} }
actual_connector_id = connector.id
logger.info( logger.info(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}" f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
) )
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
context=context,
)
if ( if result.rejected:
connector.connector_type return {
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR "status": "rejected",
): "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
from app.utils.google_credentials import build_composio_credentials }
cca_id = connector.config.get("composio_connected_account_id") final_to = result.params.get("to", to)
if cca_id: final_subject = result.params.get("subject", subject)
creds = build_composio_credentials(cca_id) final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else: else:
return { result = await db_session.execute(
"status": "error", select(SearchSourceConnector).filter(
"message": "Composio connected account ID not found for this Gmail connector.", SearchSourceConnector.search_space_id == search_space_id,
} SearchSourceConnector.user_id == user_id,
else: SearchSourceConnector.connector_type.in_(_gmail_types),
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
) )
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
sent = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
) )
try: connector = result.scalars().first()
from sqlalchemy.orm.attributes import flag_modified if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
_res = await db_session.execute( logger.info(
select(SearchSourceConnector).where( f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
SearchSourceConnector.id == actual_connector_id )
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
) )
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
) )
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"): sent, error = await execute_composio_gmail_tool(
_conn.config = {**_conn.config, "auth_expired": True} connector,
flag_modified(_conn, "config") user_id,
await db_session.commit() "GMAIL_SEND_EMAIL",
except Exception: {
"user_id": "me",
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(sent, dict):
sent = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
sent = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning( logger.warning(
"Failed to persist auth_expired for connector %s", f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
actual_connector_id,
exc_info=True,
) )
return { try:
"status": "insufficient_permissions", from sqlalchemy.orm.attributes import flag_modified
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info( _res = await db_session.execute(
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" select(SearchSourceConnector).where(
) SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
kb_message_suffix = "" logger.info(
try: f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
message_id=sent.get("id", ""),
thread_id=sent.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
) )
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after send failed: {kb_err}")
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
return { kb_message_suffix = ""
"status": "success", try:
"message_id": sent.get("id"), from app.services.gmail import GmailKBSyncService
"thread_id": sent.get("threadId"),
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", kb_service = GmailKBSyncService(db_session)
} kb_result = await kb_service.sync_after_create(
message_id=sent.get("id", ""),
thread_id=sent.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after send failed: {kb_err}")
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"message_id": sent.get("id"),
"thread_id": sent.get("threadId"),
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the trash_gmail_email tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured trash_gmail_email tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def trash_gmail_email( async def trash_gmail_email(
email_subject_or_id: str, email_subject_or_id: str,
@ -55,244 +73,261 @@ def create_trash_gmail_email_tool(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}" f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GmailToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_trash_context( metadata_service = GmailToolMetadataService(db_session)
search_space_id, user_id, email_subject_or_id context = await metadata_service.get_trash_context(
) search_space_id, user_id, email_subject_or_id
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Email not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
) )
return {
"status": "auth_error",
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"] if "error" in context:
message_id = email["message_id"] error_msg = context["error"]
document_id = email.get("document_id") if "not found" in error_msg.lower():
connector_id_from_context = context["account"]["id"] logger.warning(f"Email not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
if not message_id: account = context.get("account", {})
return { if account.get("auth_expired"):
"status": "error", logger.warning(
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.", "Gmail account %s has expired authentication",
} account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info( email = context["email"]
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})" message_id = email["message_id"]
) document_id = email.get("document_id")
result = request_approval( connector_id_from_context = context["account"]["id"]
action_type="gmail_email_trash",
tool_name="trash_gmail_email",
params={
"message_id": message_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected: if not message_id:
return {
"status": "rejected",
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
final_message_id = result.params.get("message_id", message_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this email.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this Gmail connector.", "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
} }
else:
from google.oauth2.credentials import Credentials
from app.config import config logger.info(
from app.utils.oauth_security import TokenEncryption f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
)
config_data = dict(connector.config) result = request_approval(
token_encrypted = config_data.get("_token_encrypted", False) action_type="gmail_email_trash",
if token_encrypted and config.SECRET_KEY: tool_name="trash_gmail_email",
token_encryption = TokenEncryption(config.SECRET_KEY) params={
if config_data.get("token"): "message_id": message_id,
config_data["token"] = token_encryption.decrypt_token( "connector_id": connector_id_from_context,
config_data["token"] "delete_from_kb": delete_from_kb,
) },
if config_data.get("refresh_token"): context=context,
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
from googleapiclient.discovery import build if result.rejected:
gmail_service = build("gmail", "v1", credentials=creds)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.trash(userId="me", id=final_message_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "rejected",
"connector_id": connector.id, "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
logger.info(f"Gmail email trashed: message_id={final_message_id}") final_message_id = result.params.get("message_id", message_id)
final_connector_id = result.params.get(
trash_result: dict[str, Any] = { "connector_id", connector_id_from_context
"status": "success", )
"message_id": final_message_id, final_delete_from_kb = result.params.get(
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.", "delete_from_kb", delete_from_kb
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"Email trashed, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
) )
return trash_result if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this email.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
is_composio_gmail = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
)
if is_composio_gmail:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
_trashed, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_MOVE_TO_TRASH",
{"user_id": "me", "message_id": final_message_id},
)
if error:
raise RuntimeError(error)
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.trash(userId="me", id=final_message_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Gmail email trashed: message_id={final_message_id}")
trash_result: dict[str, Any] = {
"status": "success",
"message_id": final_message_id,
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"Email trashed, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the update_gmail_draft tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_gmail_draft tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_gmail_draft( async def update_gmail_draft(
draft_subject_or_id: str, draft_subject_or_id: str,
@ -76,294 +94,329 @@ def create_update_gmail_draft_tool(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'" f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Gmail tool not properly configured. Please contact support.", "message": "Gmail tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GmailToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = GmailToolMetadataService(db_session)
search_space_id, user_id, draft_subject_or_id context = await metadata_service.get_update_context(
) search_space_id, user_id, draft_subject_or_id
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Draft not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = account["id"]
draft_id_from_context = context.get("draft_id")
original_subject = email.get("subject", draft_subject_or_id)
final_subject_default = subject if subject else original_subject
final_to_default = to if to else ""
logger.info(
f"Requesting approval for updating Gmail draft: '{original_subject}' "
f"(message_id={message_id}, draft_id={draft_id_from_context})"
)
result = request_approval(
action_type="gmail_draft_update",
tool_name="update_gmail_draft",
params={
"message_id": message_id,
"draft_id": draft_id_from_context,
"to": final_to_default,
"subject": final_subject_default,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", final_to_default)
final_subject = result.params.get("subject", final_subject_default)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_draft_id = result.params.get("draft_id", draft_id_from_context)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this draft.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
from googleapiclient.discovery import build if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Draft not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
gmail_service = build("gmail", "v1", credentials=creds) account = context.get("account", {})
if account.get("auth_expired"):
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
final_draft_id = await _find_draft_id_by_message(
gmail_service, message_id
)
if not final_draft_id:
return {
"status": "error",
"message": (
"Could not find this draft in Gmail. "
"It may have already been sent or deleted."
),
}
message = MIMEText(final_body)
if final_to:
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.update(
userId="me",
id=final_draft_id,
body={"message": {"raw": raw}},
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning( logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}" "Gmail account %s has expired authentication",
account.get("id"),
) )
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "auth_error",
"connector_id": connector.id, "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", "connector_type": "gmail",
} }
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = account["id"]
draft_id_from_context = context.get("draft_id")
original_subject = email.get("subject", draft_subject_or_id)
final_subject_default = subject if subject else original_subject
final_to_default = to if to else ""
logger.info(
f"Requesting approval for updating Gmail draft: '{original_subject}' "
f"(message_id={message_id}, draft_id={draft_id_from_context})"
)
result = request_approval(
action_type="gmail_draft_update",
tool_name="update_gmail_draft",
params={
"message_id": message_id,
"draft_id": draft_id_from_context,
"to": final_to_default,
"subject": final_subject_default,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
}
final_to = result.params.get("to", final_to_default)
final_subject = result.params.get("subject", final_subject_default)
final_body = result.params.get("body", body)
final_cc = result.params.get("cc", cc)
final_bcc = result.params.get("bcc", bcc)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_draft_id = result.params.get("draft_id", draft_id_from_context)
if not final_connector_id:
return { return {
"status": "error", "status": "error",
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.", "message": "No connector found for this draft.",
} }
raise
logger.info(f"Gmail draft updated: id={updated.get('id')}") from sqlalchemy.future import select
kb_message_suffix = "" from app.db import SearchSourceConnector, SearchSourceConnectorType
if document_id:
try:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
from app.db import Document _gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
doc_result = await db_session.execute( result = await db_session.execute(
sa_select(Document).filter(Document.id == document_id) select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
) )
document = doc_result.scalars().first() )
if document: connector = result.scalars().first()
document.source_markdown = final_body if not connector:
document.title = final_subject return {
meta = dict(document.document_metadata or {}) "status": "error",
meta["subject"] = final_subject "message": "Selected Gmail connector is invalid or has been disconnected.",
meta["draft_id"] = updated.get("id", final_draft_id) }
updated_msg = updated.get("message", {})
if updated_msg.get("id"): logger.info(
meta["message_id"] = updated_msg["id"] f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
document.document_metadata = meta )
flag_modified(document, "document_metadata")
await db_session.commit() is_composio_gmail = (
kb_message_suffix = ( connector.connector_type
" Your knowledge base has also been updated." == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
) )
logger.info( if is_composio_gmail:
f"KB document {document_id} updated for draft {final_draft_id}" cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = (
token_encryption.decrypt_token(
config_data["refresh_token"]
)
)
if config_data.get("client_secret"):
config_data["client_secret"] = (
token_encryption.decrypt_token(
config_data["client_secret"]
)
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
if is_composio_gmail:
final_draft_id = await _find_composio_draft_id_by_message(
connector, user_id, message_id
) )
else: else:
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync." from googleapiclient.discovery import build
except Exception as kb_err:
logger.warning(f"KB update after draft edit failed: {kb_err}")
await db_session.rollback()
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
return { gmail_service = build("gmail", "v1", credentials=creds)
"status": "success", final_draft_id = await _find_draft_id_by_message(
"draft_id": updated.get("id"), gmail_service, message_id
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}", )
}
if not final_draft_id:
return {
"status": "error",
"message": (
"Could not find this draft in Gmail. "
"It may have already been sent or deleted."
),
}
message = MIMEText(final_body)
if final_to:
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
if is_composio_gmail:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
split_recipients,
)
updated, error = await execute_composio_gmail_tool(
connector,
user_id,
"GMAIL_UPDATE_DRAFT",
{
"user_id": "me",
"draft_id": final_draft_id,
"recipient_email": final_to,
"subject": final_subject,
"body": final_body,
"cc": split_recipients(final_cc),
"bcc": split_recipients(final_bcc),
"is_html": False,
},
)
if error:
raise RuntimeError(error)
if not isinstance(updated, dict):
updated = {}
else:
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.update(
userId="me",
id=final_draft_id,
body={"message": {"raw": raw}},
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
return {
"status": "error",
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
}
raise
logger.info(f"Gmail draft updated: id={updated.get('id')}")
kb_message_suffix = ""
if document_id:
try:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
from app.db import Document
doc_result = await db_session.execute(
sa_select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
document.source_markdown = final_body
document.title = final_subject
meta = dict(document.document_metadata or {})
meta["subject"] = final_subject
meta["draft_id"] = updated.get("id", final_draft_id)
updated_msg = updated.get("message", {})
if updated_msg.get("id"):
meta["message_id"] = updated_msg["id"]
document.document_metadata = meta
flag_modified(document, "document_metadata")
await db_session.commit()
kb_message_suffix = (
" Your knowledge base has also been updated."
)
logger.info(
f"KB document {document_id} updated for draft {final_draft_id}"
)
else:
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB update after draft edit failed: {kb_err}")
await db_session.rollback()
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
return {
"status": "success",
"draft_id": updated.get("id"),
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt
@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
except Exception as e: except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}") logger.warning(f"Failed to look up draft by message_id: {e}")
return None return None
async def _find_composio_draft_id_by_message(
connector: Any, user_id: str, message_id: str
) -> str | None:
from app.agents.new_chat.tools.gmail.composio_helpers import (
execute_composio_gmail_tool,
)
page_token = ""
while True:
params: dict[str, Any] = {
"user_id": "me",
"max_results": 100,
"verbose": False,
}
if page_token:
params["page_token"] = page_token
data, error = await execute_composio_gmail_tool(
connector, user_id, "GMAIL_LIST_DRAFTS", params
)
if error or not isinstance(data, dict):
return None
for draft in data.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft.get("id")
page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
if not page_token:
return None

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_calendar_event( async def create_calendar_event(
summary: str, summary: str,
@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'" f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GoogleCalendarToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = GoogleCalendarToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
"All Google Calendar accounts have expired authentication"
) )
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info( if "error" in context:
f"Requesting approval for creating calendar event: summary='{summary}'" logger.error(
) f"Failed to fetch creation context: {context['error']}"
result = request_approval(
action_type="google_calendar_event_creation",
tool_name="create_calendar_event",
params={
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": description,
"location": location,
"attendees": attendees,
"timezone": context.get("timezone"),
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
final_summary = result.params.get("summary", summary)
final_start_datetime = result.params.get("start_datetime", start_datetime)
final_end_datetime = result.params.get("end_datetime", end_datetime)
final_description = result.params.get("description", description)
final_location = result.params.get("location", location)
final_attendees = result.params.get("attendees", attendees)
final_connector_id = result.params.get("connector_id")
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Event summary cannot be empty."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
) )
) return {"status": "error", "message": context["error"]}
connector = result.scalars().first()
if not connector: accounts = context.get("accounts", [])
return { if accounts and all(a.get("auth_expired") for a in accounts):
"status": "error", logger.warning(
"message": "Selected Google Calendar connector is invalid or has been disconnected.", "All Google Calendar accounts have expired authentication"
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
) )
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for creating calendar event: summary='{summary}'"
) )
connector = result.scalars().first() result = request_approval(
if not connector: action_type="google_calendar_event_creation",
return { tool_name="create_calendar_event",
"status": "error", params={
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", "summary": summary,
} "start_datetime": start_datetime,
actual_connector_id = connector.id "end_datetime": end_datetime,
"description": description,
logger.info( "location": location,
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" "attendees": attendees,
) "timezone": context.get("timezone"),
"connector_id": None,
if ( },
connector.connector_type context=context,
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
service = await asyncio.get_event_loop().run_in_executor( if result.rejected:
None, lambda: build("calendar", "v3", credentials=creds) return {
) "status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
tz = context.get("timezone", "UTC") final_summary = result.params.get("summary", summary)
event_body: dict[str, Any] = { final_start_datetime = result.params.get(
"summary": final_summary, "start_datetime", start_datetime
"start": {"dateTime": final_start_datetime, "timeZone": tz}, )
"end": {"dateTime": final_end_datetime, "timeZone": tz}, final_end_datetime = result.params.get("end_datetime", end_datetime)
} final_description = result.params.get("description", description)
if final_description: final_location = result.params.get("location", location)
event_body["description"] = final_description final_attendees = result.params.get("attendees", attendees)
if final_location: final_connector_id = result.params.get("connector_id")
event_body["location"] = final_location
if final_attendees: if not final_summary or not final_summary.strip():
event_body["attendees"] = [ return {
{"email": e.strip()} for e in final_attendees if e.strip() "status": "error",
"message": "Event summary cannot be empty.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
] ]
try: if final_connector_id is not None:
created = await asyncio.get_event_loop().run_in_executor( result = await db_session.execute(
None, select(SearchSourceConnector).filter(
lambda: ( SearchSourceConnector.id == final_connector_id,
service.events() SearchSourceConnector.search_space_id == search_space_id,
.insert(calendarId="primary", body=event_body) SearchSourceConnector.user_id == user_id,
.execute() SearchSourceConnector.connector_type.in_(_calendar_types),
), )
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
) )
try: connector = result.scalars().first()
from sqlalchemy.orm.attributes import flag_modified if not connector:
return {
_res = await db_session.execute( "status": "error",
select(SearchSourceConnector).where( "message": "Selected Google Calendar connector is invalid or has been disconnected.",
SearchSourceConnector.id == actual_connector_id }
) actual_connector_id = connector.id
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else: else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." result = await db_session.execute(
except Exception as kb_err: select(SearchSourceConnector).filter(
logger.warning(f"KB sync after create failed: {kb_err}") SearchSourceConnector.search_space_id == search_space_id,
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync." SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
actual_connector_id = connector.id
return { logger.info(
"status": "success", f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
"event_id": created.get("id"), )
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}", is_composio_calendar = (
} connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
"start": {"dateTime": final_start_datetime, "timeZone": tz},
"end": {"dateTime": final_end_datetime, "timeZone": tz},
}
if final_description:
event_body["description"] = final_description
if final_location:
event_body["location"] = final_location
if final_attendees:
event_body["attendees"] = [
{"email": e.strip()} for e in final_attendees if e.strip()
]
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params = {
"calendar_id": "primary",
"summary": final_summary,
"start_datetime": final_start_datetime,
"end_datetime": final_end_datetime,
"timezone": tz,
"attendees": final_attendees or [],
}
if final_description:
composio_params["description"] = final_description
if final_location:
composio_params["location"] = final_location
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_CREATE_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
created = composio_result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.insert(calendarId="primary", body=event_body)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"event_id": created.get("id"),
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_calendar_event( async def delete_calendar_event(
event_title_or_id: str, event_title_or_id: str,
@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}" f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GoogleCalendarToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_deletion_context( metadata_service = GoogleCalendarToolMetadataService(db_session)
search_space_id, user_id, event_title_or_id context = await metadata_service.get_deletion_context(
) search_space_id, user_id, event_title_or_id
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch deletion context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Calendar account %s has expired authentication",
account.get("id"),
) )
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"] if "error" in context:
event_id = event["event_id"] error_msg = context["error"]
document_id = event.get("document_id") if "not found" in error_msg.lower():
connector_id_from_context = context["account"]["id"] logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch deletion context: {error_msg}")
return {"status": "error", "message": error_msg}
if not event_id: account = context.get("account", {})
return { if account.get("auth_expired"):
"status": "error", logger.warning(
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.", "Google Calendar account %s has expired authentication",
} account.get("id"),
)
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info( event = context["event"]
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})" event_id = event["event_id"]
) document_id = event.get("document_id")
result = request_approval( connector_id_from_context = context["account"]["id"]
action_type="google_calendar_event_deletion",
tool_name="delete_calendar_event",
params={
"event_id": event_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected: if not event_id:
return {
"status": "rejected",
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
} }
else:
config_data = dict(connector.config)
from app.config import config as app_config logger.info(
from app.utils.oauth_security import TokenEncryption f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
)
token_encrypted = config_data.get("_token_encrypted", False) result = request_approval(
if token_encrypted and app_config.SECRET_KEY: action_type="google_calendar_event_deletion",
token_encryption = TokenEncryption(app_config.SECRET_KEY) tool_name="delete_calendar_event",
for key in ("token", "refresh_token", "client_secret"): params={
if config_data.get(key): "event_id": event_id,
config_data[key] = token_encryption.decrypt_token( "connector_id": connector_id_from_context,
config_data[key] "delete_from_kb": delete_from_kb,
) },
context=context,
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
service = await asyncio.get_event_loop().run_in_executor( if result.rejected:
None, lambda: build("calendar", "v3", credentials=creds)
)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.delete(calendarId="primary", eventId=final_event_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "rejected",
"connector_id": actual_connector_id, "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
logger.info(f"Calendar event deleted: event_id={final_event_id}") final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
delete_result: dict[str, Any] = { "connector_id", connector_id_from_context
"status": "success", )
"event_id": final_event_id, final_delete_from_kb = result.params.get(
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.", "delete_from_kb", delete_from_kb
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
delete_result["warning"] = (
f"Event deleted, but failed to remove from knowledge base: {e!s}"
)
delete_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
delete_result["message"] = (
f"{delete_result.get('message', '')} (also removed from knowledge base)"
) )
return delete_result if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_DELETE_EVENT",
params={
"calendar_id": "primary",
"event_id": final_event_id,
},
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.delete(calendarId="primary", eventId=final_event_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event deleted: event_id={final_event_id}")
delete_result: dict[str, Any] = {
"status": "success",
"event_id": final_event_id,
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
delete_result["warning"] = (
f"Event deleted, but failed to remove from knowledge base: {e!s}"
)
delete_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
delete_result["message"] = (
f"{delete_result.get('message', '')} (also removed from knowledge base)"
)
return delete_result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
] ]
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
if "T" in value:
return value
time = "23:59:59" if is_end else "00:00:00"
return f"{value}T{time}Z"
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return events
def create_search_calendar_events_tool( def create_search_calendar_events_tool(
db_session: AsyncSession | None = None, db_session: AsyncSession | None = None,
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the search_calendar_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured search_calendar_events tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def search_calendar_events( async def search_calendar_events(
start_date: str, start_date: str,
@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
Dictionary with status and a list of events including Dictionary with status and a list of events including
event_id, summary, start, end, location, attendees. event_id, summary, start, end, location, attendees.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Calendar tool not properly configured.", "message": "Calendar tool not properly configured.",
@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
max_results = min(max_results, 50) max_results = min(max_results, 50)
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter( result = await db_session.execute(
SearchSourceConnector.search_space_id == search_space_id, select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
)
) )
) connector = result.scalars().first()
connector = result.scalars().first() if not connector:
if not connector: return {
return { "status": "error",
"status": "error", "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", }
}
creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import GoogleCalendarConnector
cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error:
if ( if (
"re-authenticate" in error.lower() connector.connector_type
or "authentication failed" in error.lower() == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): ):
return { cca_id = connector.config.get("composio_connected_account_id")
"status": "auth_error", if not cca_id:
"message": error, return {
"connector_type": "google_calendar", "status": "error",
} "message": "Composio connected account ID not found for this connector.",
if "no events found" in error.lower(): }
return {
"status": "success",
"events": [],
"total": 0,
"message": error,
}
return {"status": "error", "message": error}
events = [] from app.services.composio_service import ComposioService
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return {"status": "success", "events": events, "total": len(events)} events_raw, error = await ComposioService().get_calendar_events(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
time_min=_to_calendar_boundary(start_date, is_end=False),
time_max=_to_calendar_boundary(end_date, is_end=True),
max_results=max_results,
)
if not events_raw and not error:
error = "No events found in the specified date range."
else:
creds = _build_credentials(connector)
from app.connectors.google_calendar_connector import (
GoogleCalendarConnector,
)
cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error:
if (
"re-authenticate" in error.lower()
or "authentication failed" in error.lower()
):
return {
"status": "auth_error",
"message": error,
"connector_type": "google_calendar",
}
if "no events found" in error.lower():
return {
"status": "success",
"events": [],
"total": 0,
"message": error,
}
return {"status": "error", "message": error}
events = _format_calendar_events(events_raw)
return {"status": "success", "events": events, "total": len(events)}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the update_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured update_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_calendar_event( async def update_calendar_event(
event_title_or_id: str, event_title_or_id: str,
@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
""" """
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'") logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.", "message": "Google Calendar tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GoogleCalendarToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = GoogleCalendarToolMetadataService(db_session)
search_space_id, user_id, event_title_or_id context = await metadata_service.get_update_context(
) search_space_id, user_id, event_title_or_id
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
if context.get("auth_expired"):
logger.warning("Google Calendar account has expired authentication")
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if not event_id:
return {
"status": "error",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
logger.info(
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
)
result = request_approval(
action_type="google_calendar_event_update",
tool_name="update_calendar_event",
params={
"event_id": event_id,
"document_id": document_id,
"connector_id": connector_id_from_context,
"new_summary": new_summary,
"new_start_datetime": new_start_datetime,
"new_end_datetime": new_end_datetime,
"new_description": new_description,
"new_location": new_location,
"new_attendees": new_attendees,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
}
final_event_id = result.params.get("event_id", event_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_new_summary = result.params.get("new_summary", new_summary)
final_new_start_datetime = result.params.get(
"new_start_datetime", new_start_datetime
)
final_new_end_datetime = result.params.get(
"new_end_datetime", new_end_datetime
)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_location = result.params.get("new_location", new_location)
final_new_attendees = result.params.get("new_attendees", new_attendees)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
) )
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
logger.info( if context.get("auth_expired"):
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" logger.warning("Google Calendar account has expired authentication")
) return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
if ( event = context["event"]
connector.connector_type event_id = event["event_id"]
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR document_id = event.get("document_id")
): connector_id_from_context = context["account"]["id"]
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id") if not event_id:
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
} }
else:
config_data = dict(connector.config)
from app.config import config as app_config logger.info(
from app.utils.oauth_security import TokenEncryption f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
)
token_encrypted = config_data.get("_token_encrypted", False) result = request_approval(
if token_encrypted and app_config.SECRET_KEY: action_type="google_calendar_event_update",
token_encryption = TokenEncryption(app_config.SECRET_KEY) tool_name="update_calendar_event",
for key in ("token", "refresh_token", "client_secret"): params={
if config_data.get(key): "event_id": event_id,
config_data[key] = token_encryption.decrypt_token( "document_id": document_id,
config_data[key] "connector_id": connector_id_from_context,
) "new_summary": new_summary,
"new_start_datetime": new_start_datetime,
exp = config_data.get("expiry", "") "new_end_datetime": new_end_datetime,
if exp: "new_description": new_description,
exp = exp.replace("Z", "") "new_location": new_location,
"new_attendees": new_attendees,
creds = Credentials( },
token=config_data.get("token"), context=context,
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
service = await asyncio.get_event_loop().run_in_executor( if result.rejected:
None, lambda: build("calendar", "v3", credentials=creds) return {
) "status": "rejected",
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
}
update_body: dict[str, Any] = {} final_event_id = result.params.get("event_id", event_id)
if final_new_summary is not None: final_connector_id = result.params.get(
update_body["summary"] = final_new_summary "connector_id", connector_id_from_context
if final_new_start_datetime is not None:
update_body["start"] = _build_time_body(
final_new_start_datetime, context
) )
if final_new_end_datetime is not None: final_new_summary = result.params.get("new_summary", new_summary)
update_body["end"] = _build_time_body(final_new_end_datetime, context) final_new_start_datetime = result.params.get(
if final_new_description is not None: "new_start_datetime", new_start_datetime
update_body["description"] = final_new_description )
if final_new_location is not None: final_new_end_datetime = result.params.get(
update_body["location"] = final_new_location "new_end_datetime", new_end_datetime
if final_new_attendees is not None: )
update_body["attendees"] = [ final_new_description = result.params.get(
{"email": e.strip()} for e in final_new_attendees if e.strip() "new_description", new_description
)
final_new_location = result.params.get("new_location", new_location)
final_new_attendees = result.params.get("new_attendees", new_attendees)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
] ]
if not update_body: result = await db_session.execute(
return { select(SearchSourceConnector).filter(
"status": "error", SearchSourceConnector.id == final_connector_id,
"message": "No changes specified. Please provide at least one field to update.", SearchSourceConnector.search_space_id == search_space_id,
} SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
try: )
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.patch(
calendarId="primary",
eventId=final_event_id,
body=update_body,
)
.execute()
),
) )
except Exception as api_err: connector = result.scalars().first()
from googleapiclient.errors import HttpError if not connector:
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "error",
"connector_id": actual_connector_id, "message": "Selected Google Calendar connector is invalid or has been disconnected.",
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
logger.info(f"Calendar event updated: event_id={final_event_id}") actual_connector_id = connector.id
kb_message_suffix = "" logger.info(
if document_id is not None: f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
try: )
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session) is_composio_calendar = (
kb_result = await kb_service.sync_after_update( connector.connector_type
document_id=document_id, == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
event_id=final_event_id, )
connector_id=actual_connector_id, if is_composio_calendar:
search_space_id=search_space_id, cca_id = connector.config.get("composio_connected_account_id")
user_id=user_id, if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
) )
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
return { update_body: dict[str, Any] = {}
"status": "success", if final_new_summary is not None:
"event_id": final_event_id, update_body["summary"] = final_new_summary
"html_link": updated.get("htmlLink"), if final_new_start_datetime is not None:
"message": f"Successfully updated the calendar event.{kb_message_suffix}", update_body["start"] = _build_time_body(
} final_new_start_datetime, context
)
if final_new_end_datetime is not None:
update_body["end"] = _build_time_body(
final_new_end_datetime, context
)
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:
update_body["location"] = final_new_location
if final_new_attendees is not None:
update_body["attendees"] = [
{"email": e.strip()} for e in final_new_attendees if e.strip()
]
if not update_body:
return {
"status": "error",
"message": "No changes specified. Please provide at least one field to update.",
}
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params: dict[str, Any] = {
"calendar_id": "primary",
"event_id": final_event_id,
}
if final_new_summary is not None:
composio_params["summary"] = final_new_summary
if final_new_start_datetime is not None:
composio_params["start_time"] = final_new_start_datetime
if final_new_end_datetime is not None:
composio_params["end_time"] = final_new_end_datetime
if final_new_description is not None:
composio_params["description"] = final_new_description
if final_new_location is not None:
composio_params["location"] = final_new_location
if final_new_attendees is not None:
composio_params["attendees"] = [
e.strip() for e in final_new_attendees if e.strip()
]
if not _is_date_only(
final_new_start_datetime or final_new_end_datetime or ""
):
composio_params["timezone"] = context.get("timezone", "UTC")
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_PATCH_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
updated = composio_result.get("data", {})
if isinstance(updated, dict):
updated = updated.get("data", updated)
if isinstance(updated, dict):
updated = updated.get("response_data", updated)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.patch(
calendarId="primary",
eventId=final_event_id,
body=update_body,
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event updated: event_id={final_event_id}")
kb_message_suffix = ""
if document_id is not None:
try:
from app.services.google_calendar import (
GoogleCalendarKBSyncService,
)
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
event_id=final_event_id,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
return {
"status": "success",
"event_id": final_event_id,
"html_link": updated.get("htmlLink"),
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured create_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_google_drive_file( async def create_google_drive_file(
name: str, name: str,
@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
f"create_google_drive_file called: name='{name}', type='{file_type}'" f"create_google_drive_file called: name='{name}', type='{file_type}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Drive tool not properly configured. Please contact support.", "message": "Google Drive tool not properly configured. Please contact support.",
@ -78,195 +98,232 @@ def create_create_google_drive_file_tool(
} }
try: try:
metadata_service = GoogleDriveToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = GoogleDriveToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
)
if "error" in context: if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(
return {"status": "error", "message": context["error"]} f"Failed to fetch creation context: {context['error']}"
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Google Drive accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
logger.info(
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
)
result = request_approval(
action_type="google_drive_file_creation",
tool_name="create_google_drive_file",
params={
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
mime_type = _MIME_MAP.get(final_file_type)
if not mime_type:
return {
"status": "error",
"message": f"Unsupported file type '{final_file_type}'.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
) )
) return {"status": "error", "message": context["error"]}
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
}
actual_connector_id = connector.id
logger.info( accounts = context.get("accounts", [])
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" if accounts and all(a.get("auth_expired") for a in accounts):
)
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
credentials=pre_built_creds,
)
try:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning( logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}" "All Google Drive accounts have expired authentication"
) )
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "auth_error",
"connector_id": actual_connector_id, "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", "connector_type": "google_drive",
} }
raise
logger.info( logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
) )
result = request_approval(
kb_message_suffix = "" action_type="google_drive_file_creation",
try: tool_name="create_google_drive_file",
from app.services.google_drive import GoogleDriveKBSyncService params={
"name": name,
kb_service = GoogleDriveKBSyncService(db_session) "file_type": file_type,
kb_result = await kb_service.sync_after_create( "content": content,
file_id=created.get("id"), "connector_id": None,
file_name=created.get("name", final_name), "parent_folder_id": None,
mime_type=mime_type, },
web_view_link=created.get("webViewLink"), context=context,
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
) )
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return { if result.rejected:
"status": "success", return {
"file_id": created.get("id"), "status": "rejected",
"name": created.get("name"), "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
"web_view_link": created.get("webViewLink"), }
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
} final_name = result.params.get("name", name)
final_file_type = result.params.get("file_type", file_type)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
mime_type = _MIME_MAP.get(final_file_type)
if not mime_type:
return {
"status": "error",
"message": f"Unsupported file type '{final_file_type}'.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
)
try:
if is_composio_drive:
from app.services.composio_service import ComposioService
params: dict[str, Any] = {
"name": final_name,
"mimeType": mime_type,
"fields": "id,name,webViewLink,mimeType",
}
if final_parent_folder_id:
params["parents"] = [final_parent_folder_id]
if final_content:
params["description"] = final_content[:4096]
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_CREATE_FILE",
params=params,
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
created = result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
if not isinstance(created, dict):
created = {}
else:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.google_drive import GoogleDriveKBSyncService
kb_service = GoogleDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=mime_type,
web_view_link=created.get("webViewLink"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_view_link": created.get("webViewLink"),
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.client import GoogleDriveClient
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_google_drive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Google Drive connector
user_id: User ID for fetching user-specific context
Returns:
Configured delete_google_drive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_google_drive_file( async def delete_google_drive_file(
file_name: str, file_name: str,
@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "Google Drive tool not properly configured. Please contact support.", "message": "Google Drive tool not properly configured. Please contact support.",
} }
try: try:
metadata_service = GoogleDriveToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_trash_context( metadata_service = GoogleDriveToolMetadataService(db_session)
search_space_id, user_id, file_name context = await metadata_service.get_trash_context(
) search_space_id, user_id, file_name
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"File not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Drive account %s has expired authentication",
account.get("id"),
) )
return {
"status": "auth_error",
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
file = context["file"] if "error" in context:
file_id = file["file_id"] error_msg = context["error"]
document_id = file.get("document_id") if "not found" in error_msg.lower():
connector_id_from_context = context["account"]["id"] logger.warning(f"File not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
if not file_id: account = context.get("account", {})
return { if account.get("auth_expired"):
"status": "error",
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
}
logger.info(
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="google_drive_file_trash",
tool_name="delete_google_drive_file",
params={
"file_id": file_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
credentials=pre_built_creds,
)
try:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning( logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}" "Google Drive account %s has expired authentication",
account.get("id"),
) )
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "auth_error",
"connector_id": connector.id, "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", "connector_type": "google_drive",
} }
raise
logger.info( file = context["file"]
f"Google Drive file deleted (moved to trash): file_id={final_file_id}" file_id = file["file_id"]
) document_id = file.get("document_id")
connector_id_from_context = context["account"]["id"]
trash_result: dict[str, Any] = { if not file_id:
"status": "success", return {
"file_id": final_file_id, "status": "error",
"message": f"Successfully moved '{file['name']}' to trash.", "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
} }
deleted_from_kb = False logger.info(
if final_delete_from_kb and document_id: f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
try: )
from app.db import Document result = request_approval(
action_type="google_drive_file_trash",
doc_result = await db_session.execute( tool_name="delete_google_drive_file",
select(Document).filter(Document.id == document_id) params={
) "file_id": file_id,
document = doc_result.scalars().first() "connector_id": connector_id_from_context,
if document: "delete_from_kb": delete_from_kb,
await db_session.delete(document) },
await db_session.commit() context=context,
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
) )
return trash_result if result.rejected:
return {
"status": "rejected",
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_drive_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
is_composio_drive = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
)
if is_composio_drive:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Drive connector.",
}
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
)
try:
if is_composio_drive:
from app.services.composio_service import ComposioService
result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLEDRIVE_TRASH_FILE",
params={"file_id": final_file_id},
entity_id=f"surfsense_{user_id}",
)
if not result.get("success"):
raise RuntimeError(
result.get("error", "Unknown Composio Drive error")
)
else:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file['name']}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{ {
"create_gmail_draft", "create_gmail_draft",
"update_gmail_draft", "update_gmail_draft",
"create_calendar_event",
"create_notion_page", "create_notion_page",
"create_confluence_page", "create_confluence_page",
"create_google_drive_file", "create_google_drive_file",

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the create_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Jira API
blocking.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured create_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_jira_issue( async def create_jira_issue(
project_key: str, project_key: str,
@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'" f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
metadata_service = JiraToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = JiraToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Jira accounts need re-authentication.",
"connector_type": "jira",
}
result = request_approval(
action_type="jira_issue_creation",
tool_name="create_jira_issue",
params={
"project_key": project_key,
"summary": summary,
"issue_type": issue_type,
"description": description,
"priority": priority,
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_project_key = result.params.get("project_key", project_key)
final_summary = result.params.get("summary", summary)
final_issue_type = result.params.get("issue_type", issue_type)
final_description = result.params.get("description", description)
final_priority = result.params.get("priority", priority)
final_connector_id = result.params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Issue summary cannot be empty."}
if not final_project_key:
return {"status": "error", "message": "A project must be selected."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
return {"status": "error", "message": "No Jira connector found."} return {"status": "error", "message": context["error"]}
actual_connector_id = connector.id
else: accounts = context.get("accounts", [])
result = await db_session.execute( if accounts and all(a.get("auth_expired") for a in accounts):
select(SearchSourceConnector).filter( return {
SearchSourceConnector.id == actual_connector_id, "status": "auth_error",
SearchSourceConnector.search_space_id == search_space_id, "message": "All connected Jira accounts need re-authentication.",
SearchSourceConnector.user_id == user_id, "connector_type": "jira",
SearchSourceConnector.connector_type }
== SearchSourceConnectorType.JIRA_CONNECTOR,
) result = request_approval(
action_type="jira_issue_creation",
tool_name="create_jira_issue",
params={
"project_key": project_key,
"summary": summary,
"issue_type": issue_type,
"description": description,
"priority": priority,
"connector_id": connector_id,
},
context=context,
) )
connector = result.scalars().first()
if not connector: if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_project_key = result.params.get("project_key", project_key)
final_summary = result.params.get("summary", summary)
final_issue_type = result.params.get("issue_type", issue_type)
final_description = result.params.get("description", description)
final_priority = result.params.get("priority", priority)
final_connector_id = result.params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip():
return { return {
"status": "error", "status": "error",
"message": "Selected Jira connector is invalid.", "message": "Issue summary cannot be empty.",
} }
if not final_project_key:
return {"status": "error", "message": "A project must be selected."}
try: from sqlalchemy.future import select
jira_history = JiraHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
jira_client = await jira_history._get_jira_client()
api_result = await asyncio.to_thread(
jira_client.create_issue,
project_key=final_project_key,
summary=final_summary,
issue_type=final_issue_type,
description=final_description,
priority=final_priority,
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_key = api_result.get("key", "") from app.db import SearchSourceConnector, SearchSourceConnectorType
issue_url = (
f"{jira_history._base_url}/browse/{issue_key}"
if jira_history._base_url and issue_key
else ""
)
kb_message_suffix = "" actual_connector_id = final_connector_id
try: if actual_connector_id is None:
from app.services.jira import JiraKBSyncService result = await db_session.execute(
select(SearchSourceConnector).filter(
kb_service = JiraKBSyncService(db_session) SearchSourceConnector.search_space_id == search_space_id,
kb_result = await kb_service.sync_after_create( SearchSourceConnector.user_id == user_id,
issue_id=issue_key, SearchSourceConnector.connector_type
issue_identifier=issue_key, == SearchSourceConnectorType.JIRA_CONNECTOR,
issue_title=final_summary, )
description=final_description, )
state="To Do", connector = result.scalars().first()
connector_id=actual_connector_id, if not connector:
search_space_id=search_space_id, return {
user_id=user_id, "status": "error",
) "message": "No Jira connector found.",
if kb_result["status"] == "success": }
kb_message_suffix = " Your knowledge base has also been updated." actual_connector_id = connector.id
else: else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." result = await db_session.execute(
except Exception as kb_err: select(SearchSourceConnector).filter(
logger.warning(f"KB sync after create failed: {kb_err}") SearchSourceConnector.id == actual_connector_id,
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
return { try:
"status": "success", jira_history = JiraHistoryConnector(
"issue_key": issue_key, session=db_session, connector_id=actual_connector_id
"issue_url": issue_url, )
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}", jira_client = await jira_history._get_jira_client()
} api_result = await asyncio.to_thread(
jira_client.create_issue,
project_key=final_project_key,
summary=final_summary,
issue_type=final_issue_type,
description=final_description,
priority=final_priority,
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_key = api_result.get("key", "")
issue_url = (
f"{jira_history._base_url}/browse/{issue_key}"
if jira_history._base_url and issue_key
else ""
)
kb_message_suffix = ""
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=issue_key,
issue_identifier=issue_key,
issue_title=final_summary,
description=final_description,
state="To Do",
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_key": issue_key,
"issue_url": issue_url,
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the delete_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured delete_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_jira_issue( async def delete_jira_issue(
issue_title_or_key: str, issue_title_or_key: str,
@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'" f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
metadata_service = JiraToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_deletion_context( metadata_service = JiraToolMetadataService(db_session)
search_space_id, user_id, issue_title_or_key context = await metadata_service.get_deletion_context(
) search_space_id, user_id, issue_title_or_key
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_deletion",
tool_name="delete_jira_issue",
params={
"issue_key": issue_key,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_key = result.params.get("issue_key", issue_key)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
) )
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
try: if "error" in context:
jira_history = JiraHistoryConnector( error_msg = context["error"]
session=db_session, connector_id=final_connector_id if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_deletion",
tool_name="delete_jira_issue",
params={
"issue_key": issue_key,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
) )
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.delete_issue, final_issue_key) if result.rejected:
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return { return {
"status": "insufficient_permissions", "status": "rejected",
"connector_id": final_connector_id, "message": "User declined. Do not retry or suggest alternatives.",
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
deleted_from_kb = False final_issue_key = result.params.get("issue_key", issue_key)
if final_delete_from_kb and document_id: final_connector_id = result.params.get(
try: "connector_id", connector_id_from_context
from app.db import Document )
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
doc_result = await db_session.execute( from sqlalchemy.future import select
select(Document).filter(Document.id == document_id)
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
) )
document = doc_result.scalars().first() )
if document: connector = result.scalars().first()
await db_session.delete(document) if not connector:
await db_session.commit() return {
deleted_from_kb = True "status": "error",
except Exception as e: "message": "Selected Jira connector is invalid.",
logger.error(f"Failed to delete document from KB: {e}") }
await db_session.rollback()
message = f"Jira issue {final_issue_key} deleted successfully." try:
if deleted_from_kb: jira_history = JiraHistoryConnector(
message += " Also removed from the knowledge base." session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
return { deleted_from_kb = False
"status": "success", if final_delete_from_kb and document_id:
"issue_key": final_issue_key, try:
"deleted_from_kb": deleted_from_kb, from app.db import Document
"message": message,
} doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Jira issue {final_issue_key} deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"issue_key": final_issue_key,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
"""Factory function to create the update_jira_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Jira connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured update_jira_issue tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_jira_issue( async def update_jira_issue(
issue_title_or_key: str, issue_title_or_key: str,
@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'" f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."} return {"status": "error", "message": "Jira tool not properly configured."}
try: try:
metadata_service = JiraToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = JiraToolMetadataService(db_session)
search_space_id, user_id, issue_title_or_key context = await metadata_service.get_update_context(
) search_space_id, user_id, issue_title_or_key
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
result = request_approval(
action_type="jira_issue_update",
tool_name="update_jira_issue",
params={
"issue_key": issue_key,
"document_id": document_id,
"new_summary": new_summary,
"new_description": new_description,
"new_priority": new_priority,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_key = result.params.get("issue_key", issue_key)
final_summary = result.params.get("new_summary", new_summary)
final_description = result.params.get("new_description", new_description)
final_priority = result.params.get("new_priority", new_priority)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
) )
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
fields: dict[str, Any] = {} if "error" in context:
if final_summary: error_msg = context["error"]
fields["summary"] = final_summary if context.get("auth_expired"):
if final_description is not None: return {
fields["description"] = { "status": "auth_error",
"type": "doc", "message": error_msg,
"version": 1, "connector_id": context.get("connector_id"),
"content": [ "connector_type": "jira",
{
"type": "paragraph",
"content": [{"type": "text", "text": final_description}],
} }
], if "not found" in error_msg.lower():
} return {"status": "not_found", "message": error_msg}
if final_priority: return {"status": "error", "message": error_msg}
fields["priority"] = {"name": final_priority}
if not fields: issue_data = context["issue"]
return {"status": "error", "message": "No changes specified."} issue_key = issue_data["issue_id"]
document_id = issue_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
try: result = request_approval(
jira_history = JiraHistoryConnector( action_type="jira_issue_update",
session=db_session, connector_id=final_connector_id tool_name="update_jira_issue",
params={
"issue_key": issue_key,
"document_id": document_id,
"new_summary": new_summary,
"new_description": new_description,
"new_priority": new_priority,
"connector_id": connector_id_from_context,
},
context=context,
) )
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread( if result.rejected:
jira_client.update_issue, final_issue_key, fields
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return { return {
"status": "insufficient_permissions", "status": "rejected",
"connector_id": final_connector_id, "message": "User declined. Do not retry or suggest alternatives.",
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
issue_url = ( final_issue_key = result.params.get("issue_key", issue_key)
f"{jira_history._base_url}/browse/{final_issue_key}" final_summary = result.params.get("new_summary", new_summary)
if jira_history._base_url and final_issue_key final_description = result.params.get(
else "" "new_description", new_description
) )
final_priority = result.params.get("new_priority", new_priority)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_document_id = result.params.get("document_id", document_id)
kb_message_suffix = "" from sqlalchemy.future import select
if final_document_id:
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session) from app.db import SearchSourceConnector, SearchSourceConnectorType
kb_result = await kb_service.sync_after_update(
document_id=final_document_id, if not final_connector_id:
issue_id=final_issue_key, return {
user_id=user_id, "status": "error",
search_space_id=search_space_id, "message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
) )
if kb_result["status"] == "success": )
kb_message_suffix = ( connector = result.scalars().first()
" Your knowledge base has also been updated." if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
fields: dict[str, Any] = {}
if final_summary:
fields["summary"] = final_summary
if final_description is not None:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [
{"type": "text", "text": final_description}
],
}
],
}
if final_priority:
fields["priority"] = {"name": final_priority}
if not fields:
return {"status": "error", "message": "No changes specified."}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(
jira_client.update_issue, final_issue_key, fields
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_url = (
f"{jira_history._base_url}/browse/{final_issue_key}"
if jira_history._base_url and final_issue_key
else ""
)
kb_message_suffix = ""
if final_document_id:
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_key,
user_id=user_id,
search_space_id=search_space_id,
) )
else: if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = ( kb_message_suffix = (
" The knowledge base will be updated in the next sync." " The knowledge base will be updated in the next sync."
) )
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return { return {
"status": "success", "status": "success",
"issue_key": final_issue_key, "issue_key": final_issue_key,
"issue_url": issue_url, "issue_url": issue_url,
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}", "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
} }
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the create_linear_issue tool.
Factory function to create the create_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
Returns: Returns:
Configured create_linear_issue tool Configured create_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def create_linear_issue( async def create_linear_issue(
@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
""" """
logger.info(f"create_linear_issue called: title='{title}'") logger.info(f"create_linear_issue called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
} }
try: try:
metadata_service = LinearToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = LinearToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
result = request_approval(
action_type="linear_issue_creation",
tool_name="create_linear_issue",
params={
"title": title,
"description": description,
"team_id": None,
"state_id": None,
"assignee_id": None,
"priority": None,
"label_ids": [],
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Linear issue creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_description = result.params.get("description", description)
final_team_id = result.params.get("team_id")
final_state_id = result.params.get("state_id")
final_assignee_id = result.params.get("assignee_id")
final_priority = result.params.get("priority")
final_label_ids = result.params.get("label_ids") or []
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {"status": "error", "message": "Issue title cannot be empty."}
if not final_team_id:
return {
"status": "error",
"message": "A team must be selected to create an issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
result = request_approval(
action_type="linear_issue_creation",
tool_name="create_linear_issue",
params={
"title": title,
"description": description,
"team_id": None,
"state_id": None,
"assignee_id": None,
"priority": None,
"label_ids": [],
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Linear issue creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_description = result.params.get("description", description)
final_team_id = result.params.get("team_id")
final_state_id = result.params.get("state_id")
final_assignee_id = result.params.get("assignee_id")
final_priority = result.params.get("priority")
final_label_ids = result.params.get("label_ids") or []
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return { return {
"status": "error", "status": "error",
"message": "No Linear connector found. Please connect Linear in your workspace settings.", "message": "Issue title cannot be empty.",
} }
actual_connector_id = connector.id if not final_team_id:
logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return { return {
"status": "error", "status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.", "message": "A team must be selected to create an issue.",
} }
logger.info(f"Validated Linear connector: id={actual_connector_id}")
logger.info( from sqlalchemy.future import select
f"Creating Linear issue with final params: title='{final_title}'"
)
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.create_issue(
team_id=final_team_id,
title=final_title,
description=final_description,
state_id=final_state_id,
assignee_id=final_assignee_id,
priority=final_priority,
label_ids=final_label_ids if final_label_ids else None,
)
if result.get("status") == "error": from app.db import SearchSourceConnector, SearchSourceConnectorType
logger.error(f"Failed to create Linear issue: {result.get('message')}")
return {"status": "error", "message": result.get("message")}
logger.info( actual_connector_id = final_connector_id
f"Linear issue created: {result.get('identifier')} - {result.get('title')}" if actual_connector_id is None:
) result = await db_session.execute(
select(SearchSourceConnector).filter(
kb_message_suffix = "" SearchSourceConnector.search_space_id == search_space_id,
try: SearchSourceConnector.user_id == user_id,
from app.services.linear import LinearKBSyncService SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
kb_service = LinearKBSyncService(db_session) )
kb_result = await kb_service.sync_after_create( )
issue_id=result.get("id"), connector = result.scalars().first()
issue_identifier=result.get("identifier", ""), if not connector:
issue_title=result.get("title", final_title), return {
issue_url=result.get("url"), "status": "error",
description=final_description, "message": "No Linear connector found. Please connect Linear in your workspace settings.",
connector_id=actual_connector_id, }
search_space_id=search_space_id, actual_connector_id = connector.id
user_id=user_id, logger.info(f"Found Linear connector: id={actual_connector_id}")
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else: else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." result = await db_session.execute(
except Exception as kb_err: select(SearchSourceConnector).filter(
logger.warning(f"KB sync after create failed: {kb_err}") SearchSourceConnector.id == actual_connector_id,
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={actual_connector_id}")
return { logger.info(
"status": "success", f"Creating Linear issue with final params: title='{final_title}'"
"issue_id": result.get("id"), )
"identifier": result.get("identifier"), linear_client = LinearConnector(
"url": result.get("url"), session=db_session, connector_id=actual_connector_id
"message": (result.get("message", "") + kb_message_suffix), )
} result = await linear_client.create_issue(
team_id=final_team_id,
title=final_title,
description=final_description,
state_id=final_state_id,
assignee_id=final_assignee_id,
priority=final_priority,
label_ids=final_label_ids if final_label_ids else None,
)
if result.get("status") == "error":
logger.error(
f"Failed to create Linear issue: {result.get('message')}"
)
return {"status": "error", "message": result.get("message")}
logger.info(
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
)
kb_message_suffix = ""
try:
from app.services.linear import LinearKBSyncService
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=result.get("id"),
issue_identifier=result.get("identifier", ""),
issue_title=result.get("title", final_title),
issue_url=result.get("url"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_id": result.get("id"),
"identifier": result.get("identifier"),
"url": result.get("url"),
"message": (result.get("message", "") + kb_message_suffix),
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the delete_linear_issue tool.
Factory function to create the delete_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
Returns: Returns:
Configured delete_linear_issue tool Configured delete_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def delete_linear_issue( async def delete_linear_issue(
@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
} }
try: try:
metadata_service = LinearToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_delete_context( metadata_service = LinearToolMetadataService(db_session)
search_space_id, user_id, issue_ref context = await metadata_service.get_delete_context(
) search_space_id, user_id, issue_ref
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
issue_identifier = context["issue"].get("identifier", "")
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="linear_issue_deletion",
tool_name="delete_linear_issue",
params={
"issue_id": issue_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Linear issue deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
logger.error( error_msg = context["error"]
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
issue_identifier = context["issue"].get("identifier", "")
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="linear_issue_deletion",
tool_name="delete_linear_issue",
params={
"issue_id": issue_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Linear issue deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
) )
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(f"Validated Linear connector: id={actual_connector_id}")
else:
logger.error("No connector found for this issue")
return { return {
"status": "error", "status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.", "message": "No connector found for this issue.",
} }
actual_connector_id = connector.id
logger.info(f"Validated Linear connector: id={actual_connector_id}")
else:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
linear_client = LinearConnector( linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id session=db_session, connector_id=actual_connector_id
) )
result = await linear_client.archive_issue(issue_id=final_issue_id) result = await linear_client.archive_issue(issue_id=final_issue_id)
logger.info( logger.info(
f"archive_issue result: {result.get('status')} - {result.get('message', '')}" f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
) )
deleted_from_kb = False deleted_from_kb = False
if ( if (
result.get("status") == "success" result.get("status") == "success"
and final_delete_from_kb and final_delete_from_kb
and document_id and document_id
): ):
try: try:
from app.db import Document from app.db import Document
doc_result = await db_session.execute( doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id) select(Document).filter(Document.id == document_id)
) )
document = doc_result.scalars().first() document = doc_result.scalars().first()
if document: if document:
await db_session.delete(document) await db_session.delete(document)
await db_session.commit() await db_session.commit()
deleted_from_kb = True deleted_from_kb = True
logger.info( logger.info(
f"Deleted document {document_id} from knowledge base" f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
) )
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
if result.get("status") == "success": if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb result["deleted_from_kb"] = deleted_from_kb
if issue_identifier: if issue_identifier:
result["message"] = ( result["message"] = (
f"Issue {issue_identifier} archived successfully." f"Issue {issue_identifier} archived successfully."
) )
if deleted_from_kb: if deleted_from_kb:
result["message"] = ( result["message"] = (
f"{result.get('message', '')} Also removed from the knowledge base." f"{result.get('message', '')} Also removed from the knowledge base."
) )
return result return result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService from app.services.linear import LinearKBSyncService, LinearToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
user_id: str | None = None, user_id: str | None = None,
connector_id: int | None = None, connector_id: int | None = None,
): ):
""" """Factory function to create the update_linear_issue tool.
Factory function to create the update_linear_issue tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing the Linear connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Linear connector search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
Returns: Returns:
Configured update_linear_issue tool Configured update_linear_issue tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def update_linear_issue( async def update_linear_issue(
@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
""" """
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'") logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Linear tool not properly configured - missing required parameters" "Linear tool not properly configured - missing required parameters"
) )
@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
} }
try: try:
metadata_service = LinearToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = LinearToolMetadataService(db_session)
search_space_id, user_id, issue_ref context = await metadata_service.get_update_context(
) search_space_id, user_id, issue_ref
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for update context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
team = context.get("team", {})
new_state_id = _resolve_state(team, new_state_name)
new_assignee_id = _resolve_assignee(team, new_assignee_email)
new_label_ids = _resolve_labels(team, new_label_names)
logger.info(
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
)
result = request_approval(
action_type="linear_issue_update",
tool_name="update_linear_issue",
params={
"issue_id": issue_id,
"document_id": document_id,
"new_title": new_title,
"new_description": new_description,
"new_state_id": new_state_id,
"new_assignee_id": new_assignee_id,
"new_priority": new_priority,
"new_label_ids": new_label_ids,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Linear issue update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_document_id = result.params.get("document_id", document_id)
final_new_title = result.params.get("new_title", new_title)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_state_id = result.params.get("new_state_id", new_state_id)
final_new_assignee_id = result.params.get(
"new_assignee_id", new_assignee_id
)
final_new_priority = result.params.get("new_priority", new_priority)
final_new_label_ids: list[str] | None = result.params.get(
"new_label_ids", new_label_ids
)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
if not final_connector_id:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
) )
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={final_connector_id}")
logger.info( if "error" in context:
f"Updating Linear issue with final params: issue_id={final_issue_id}" error_msg = context["error"]
) if context.get("auth_expired"):
linear_client = LinearConnector( logger.warning(f"Auth expired for update context: {error_msg}")
session=db_session, connector_id=final_connector_id return {
) "status": "auth_error",
updated_issue = await linear_client.update_issue( "message": error_msg,
issue_id=final_issue_id, "connector_id": context.get("connector_id"),
title=final_new_title, "connector_type": "linear",
description=final_new_description, }
state_id=final_new_state_id, if "not found" in error_msg.lower():
assignee_id=final_new_assignee_id, logger.warning(f"Issue not found: {error_msg}")
priority=final_new_priority, return {"status": "not_found", "message": error_msg}
label_ids=final_new_label_ids, else:
) logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
if updated_issue.get("status") == "error": issue_id = context["issue"]["id"]
logger.error( document_id = context["issue"]["document_id"]
f"Failed to update Linear issue: {updated_issue.get('message')}" connector_id_from_context = context.get("workspace", {}).get("id")
)
return {
"status": "error",
"message": updated_issue.get("message"),
}
logger.info( team = context.get("team", {})
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}" new_state_id = _resolve_state(team, new_state_name)
) new_assignee_id = _resolve_assignee(team, new_assignee_email)
new_label_ids = _resolve_labels(team, new_label_names)
if final_document_id is not None:
logger.info( logger.info(
f"Updating knowledge base for document {final_document_id}..." f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
) )
kb_service = LinearKBSyncService(db_session) result = request_approval(
kb_result = await kb_service.sync_after_update( action_type="linear_issue_update",
document_id=final_document_id, tool_name="update_linear_issue",
issue_id=final_issue_id, params={
user_id=user_id, "issue_id": issue_id,
search_space_id=search_space_id, "document_id": document_id,
"new_title": new_title,
"new_description": new_description,
"new_state_id": new_state_id,
"new_assignee_id": new_assignee_id,
"new_priority": new_priority,
"new_label_ids": new_label_ids,
"connector_id": connector_id_from_context,
},
context=context,
) )
if kb_result["status"] == "success":
logger.info(
f"Knowledge base successfully updated for issue {final_issue_id}"
)
kb_message = " Your knowledge base has also been updated."
elif kb_result["status"] == "not_indexed":
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
else:
logger.warning(
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
)
kb_message = " Your knowledge base will be updated in the next scheduled sync."
else:
kb_message = ""
identifier = updated_issue.get("identifier") if result.rejected:
default_msg = f"Issue {identifier} updated successfully." logger.info("Linear issue update rejected by user")
return { return {
"status": "success", "status": "rejected",
"identifier": identifier, "message": "User declined. Do not retry or suggest alternatives.",
"url": updated_issue.get("url"), }
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
} final_issue_id = result.params.get("issue_id", issue_id)
final_document_id = result.params.get("document_id", document_id)
final_new_title = result.params.get("new_title", new_title)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_state_id = result.params.get("new_state_id", new_state_id)
final_new_assignee_id = result.params.get(
"new_assignee_id", new_assignee_id
)
final_new_priority = result.params.get("new_priority", new_priority)
final_new_label_ids: list[str] | None = result.params.get(
"new_label_ids", new_label_ids
)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
if not final_connector_id:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={final_connector_id}")
logger.info(
f"Updating Linear issue with final params: issue_id={final_issue_id}"
)
linear_client = LinearConnector(
session=db_session, connector_id=final_connector_id
)
updated_issue = await linear_client.update_issue(
issue_id=final_issue_id,
title=final_new_title,
description=final_new_description,
state_id=final_new_state_id,
assignee_id=final_new_assignee_id,
priority=final_new_priority,
label_ids=final_new_label_ids,
)
if updated_issue.get("status") == "error":
logger.error(
f"Failed to update Linear issue: {updated_issue.get('message')}"
)
return {
"status": "error",
"message": updated_issue.get("message"),
}
logger.info(
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
)
if final_document_id is not None:
logger.info(
f"Updating knowledge base for document {final_document_id}..."
)
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_id,
user_id=user_id,
search_space_id=search_space_id,
)
if kb_result["status"] == "success":
logger.info(
f"Knowledge base successfully updated for issue {final_issue_id}"
)
kb_message = " Your knowledge base has also been updated."
elif kb_result["status"] == "not_indexed":
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
else:
logger.warning(
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
)
kb_message = " Your knowledge base will be updated in the next scheduled sync."
else:
kb_message = ""
identifier = updated_issue.get("identifier")
default_msg = f"Issue {identifier} updated successfully."
return {
"status": "success",
"identifier": identifier,
"url": updated_issue.get("url"),
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
@ -17,6 +18,23 @@ def create_create_luma_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_luma_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_luma_event( async def create_luma_event(
name: str, name: str,
@ -40,83 +58,86 @@ def create_create_luma_event_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_luma_connector(
return {"status": "error", "message": "No Luma connector found."} db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
result = request_approval( result = request_approval(
action_type="luma_create_event", action_type="luma_create_event",
tool_name="create_luma_event", tool_name="create_luma_event",
params={ params={
"name": name, "name": name,
"start_at": start_at, "start_at": start_at,
"end_at": end_at, "end_at": end_at,
"description": description, "description": description,
"timezone": timezone, "timezone": timezone,
}, },
context={"connector_id": connector.id}, context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Event was not created.",
}
final_name = result.params.get("name", name)
final_start = result.params.get("start_at", start_at)
final_end = result.params.get("end_at", end_at)
final_desc = result.params.get("description", description)
final_tz = result.params.get("timezone", timezone)
api_key = get_api_key(connector)
headers = luma_headers(api_key)
body: dict[str, Any] = {
"name": final_name,
"start_at": final_start,
"end_at": final_end,
"timezone": final_tz,
}
if final_desc:
body["description_md"] = final_desc
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{LUMA_API}/event/create",
headers=headers,
json=body,
) )
if resp.status_code == 401: if result.rejected:
return { return {
"status": "auth_error", "status": "rejected",
"message": "Luma API key is invalid.", "message": "User declined. Event was not created.",
"connector_type": "luma", }
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Luma Plus subscription required to create events via API.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}{resp.text[:200]}",
}
data = resp.json() final_name = result.params.get("name", name)
event_id = data.get("api_id") or data.get("event", {}).get("api_id") final_start = result.params.get("start_at", start_at)
final_end = result.params.get("end_at", end_at)
final_desc = result.params.get("description", description)
final_tz = result.params.get("timezone", timezone)
return { api_key = get_api_key(connector)
"status": "success", headers = luma_headers(api_key)
"event_id": event_id,
"message": f"Event '{final_name}' created on Luma.", body: dict[str, Any] = {
} "name": final_name,
"start_at": final_start,
"end_at": final_end,
"timezone": final_tz,
}
if final_desc:
body["description_md"] = final_desc
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{LUMA_API}/event/create",
headers=headers,
json=body,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Luma Plus subscription required to create events via API.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}{resp.text[:200]}",
}
data = resp.json()
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
return {
"status": "success",
"event_id": event_id,
"message": f"Event '{final_name}' created on Luma.",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_luma_events_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_luma_events tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_luma_events tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_luma_events( async def list_luma_events(
max_results: int = 25, max_results: int = 25,
@ -28,77 +47,80 @@ def create_list_luma_events_tool(
Dictionary with status and a list of events including Dictionary with status and a list of events including
event_id, name, start_at, end_at, location, url. event_id, name, start_at, end_at, location, url.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
max_results = min(max_results, 50) max_results = min(max_results, 50)
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_luma_connector(
return {"status": "error", "message": "No Luma connector found."} db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Luma connector found."}
api_key = get_api_key(connector) api_key = get_api_key(connector)
headers = luma_headers(api_key) headers = luma_headers(api_key)
all_entries: list[dict] = [] all_entries: list[dict] = []
cursor = None cursor = None
async with httpx.AsyncClient(timeout=20.0) as client: async with httpx.AsyncClient(timeout=20.0) as client:
while len(all_entries) < max_results: while len(all_entries) < max_results:
params: dict[str, Any] = { params: dict[str, Any] = {
"limit": min(100, max_results - len(all_entries)) "limit": min(100, max_results - len(all_entries))
} }
if cursor: if cursor:
params["cursor"] = cursor params["cursor"] = cursor
resp = await client.get( resp = await client.get(
f"{LUMA_API}/calendar/list-events", f"{LUMA_API}/calendar/list-events",
headers=headers, headers=headers,
params=params, params=params,
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
entries = data.get("entries", [])
if not entries:
break
all_entries.extend(entries)
next_cursor = data.get("next_cursor")
if not next_cursor:
break
cursor = next_cursor
events = []
for entry in all_entries[:max_results]:
ev = entry.get("event", {})
geo = ev.get("geo_info", {})
events.append(
{
"event_id": entry.get("api_id"),
"name": ev.get("name", "Untitled"),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location": geo.get("name", ""),
"url": ev.get("url", ""),
"visibility": ev.get("visibility", ""),
}
) )
if resp.status_code == 401: return {"status": "success", "events": events, "total": len(events)}
return {
"status": "auth_error",
"message": "Luma API key is invalid.",
"connector_type": "luma",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
entries = data.get("entries", [])
if not entries:
break
all_entries.extend(entries)
next_cursor = data.get("next_cursor")
if not next_cursor:
break
cursor = next_cursor
events = []
for entry in all_entries[:max_results]:
ev = entry.get("event", {})
geo = ev.get("geo_info", {})
events.append(
{
"event_id": entry.get("api_id"),
"name": ev.get("name", "Untitled"),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location": geo.get("name", ""),
"url": ev.get("url", ""),
"visibility": ev.get("visibility", ""),
}
)
return {"status": "success", "events": events, "total": len(events)}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_luma_event_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_luma_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_luma_event tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_luma_event(event_id: str) -> dict[str, Any]: async def read_luma_event(event_id: str) -> dict[str, Any]:
"""Read detailed information about a specific Luma event. """Read detailed information about a specific Luma event.
@ -26,60 +45,63 @@ def create_read_luma_event_tool(
Dictionary with status and full event details including Dictionary with status and full event details including
description, attendees count, meeting URL. description, attendees count, meeting URL.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Luma tool not properly configured."} return {"status": "error", "message": "Luma tool not properly configured."}
try: try:
connector = await get_luma_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_luma_connector(
return {"status": "error", "message": "No Luma connector found."} db_session, search_space_id, user_id
api_key = get_api_key(connector)
headers = luma_headers(api_key)
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
f"{LUMA_API}/events/{event_id}",
headers=headers,
) )
if not connector:
return {"status": "error", "message": "No Luma connector found."}
if resp.status_code == 401: api_key = get_api_key(connector)
return { headers = luma_headers(api_key)
"status": "auth_error",
"message": "Luma API key is invalid.", async with httpx.AsyncClient(timeout=15.0) as client:
"connector_type": "luma", resp = await client.get(
} f"{LUMA_API}/events/{event_id}",
if resp.status_code == 404: headers=headers,
return { )
"status": "not_found",
"message": f"Event '{event_id}' not found.", if resp.status_code == 401:
} return {
if resp.status_code != 200: "status": "auth_error",
return { "message": "Luma API key is invalid.",
"status": "error", "connector_type": "luma",
"message": f"Luma API error: {resp.status_code}", }
if resp.status_code == 404:
return {
"status": "not_found",
"message": f"Event '{event_id}' not found.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Luma API error: {resp.status_code}",
}
data = resp.json()
ev = data.get("event", data)
geo = ev.get("geo_info", {})
event_detail = {
"event_id": event_id,
"name": ev.get("name", ""),
"description": ev.get("description", ""),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location_name": geo.get("name", ""),
"address": geo.get("address", ""),
"url": ev.get("url", ""),
"meeting_url": ev.get("meeting_url", ""),
"visibility": ev.get("visibility", ""),
"cover_url": ev.get("cover_url", ""),
} }
data = resp.json() return {"status": "success", "event": event_detail}
ev = data.get("event", data)
geo = ev.get("geo_info", {})
event_detail = {
"event_id": event_id,
"name": ev.get("name", ""),
"description": ev.get("description", ""),
"start_at": ev.get("start_at", ""),
"end_at": ev.get("end_at", ""),
"timezone": ev.get("timezone", ""),
"location_name": geo.get("name", ""),
"address": geo.get("address", ""),
"url": ev.get("url", ""),
"meeting_url": ev.get("meeting_url", ""),
"visibility": ev.get("visibility", ""),
"cover_url": ev.get("cover_url", ""),
}
return {"status": "success", "event": event_detail}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,17 @@ def create_create_notion_page_tool(
""" """
Factory function to create the create_notion_page tool. Factory function to create the create_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker`. This is critical for the compiled-agent
cache: the compiled graph (and therefore this closure) is reused
across HTTP requests, so capturing a per-request session here would
surface stale/closed sessions on cache hits. Per-call sessions also
keep the request's outer transaction free of long-running Notion API
blocking.
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +39,7 @@ def create_create_notion_page_tool(
Returns: Returns:
Configured create_notion_page tool Configured create_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def create_notion_page( async def create_notion_page(
@ -67,7 +78,7 @@ def create_create_notion_page_tool(
""" """
logger.info(f"create_notion_page called: title='{title}'") logger.info(f"create_notion_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -77,154 +88,157 @@ def create_create_notion_page_tool(
} }
try: try:
metadata_service = NotionToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_creation_context( metadata_service = NotionToolMetadataService(db_session)
search_space_id, user_id context = await metadata_service.get_creation_context(
) search_space_id, user_id
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {
"status": "error",
"message": context["error"],
}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Notion accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "notion",
}
logger.info(f"Requesting approval for creating Notion page: '{title}'")
result = request_approval(
action_type="notion_page_creation",
tool_name="create_notion_page",
params={
"title": title,
"content": content,
"parent_page_id": None,
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Notion page creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_content = result.params.get("content", content)
final_parent_page_id = result.params.get("parent_page_id")
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {
"status": "error",
"message": "Page title cannot be empty. Please provide a valid title.",
}
logger.info(
f"Creating Notion page with final params: title='{final_title}'"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
) )
connector = result.scalars().first()
if not connector: if "error" in context:
logger.warning(
f"No Notion connector found for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Notion connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error( logger.error(
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}" f"Failed to fetch creation context: {context['error']}"
) )
return { return {
"status": "error", "status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", "message": context["error"],
} }
logger.info(f"Validated Notion connector: id={actual_connector_id}")
notion_connector = NotionHistoryConnector( accounts = context.get("accounts", [])
session=db_session, if accounts and all(a.get("auth_expired") for a in accounts):
connector_id=actual_connector_id, logger.warning("All Notion accounts have expired authentication")
) return {
"status": "auth_error",
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "notion",
}
result = await notion_connector.create_page( logger.info(f"Requesting approval for creating Notion page: '{title}'")
title=final_title, result = request_approval(
content=final_content, action_type="notion_page_creation",
parent_page_id=final_parent_page_id, tool_name="create_notion_page",
) params={
logger.info( "title": title,
f"create_page result: {result.get('status')} - {result.get('message', '')}" "content": content,
) "parent_page_id": None,
"connector_id": connector_id,
},
context=context,
)
if result.get("status") == "success": if result.rejected:
kb_message_suffix = "" logger.info("Notion page creation rejected by user")
try: return {
from app.services.notion import NotionKBSyncService "status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
kb_service = NotionKBSyncService(db_session) final_title = result.params.get("title", title)
kb_result = await kb_service.sync_after_create( final_content = result.params.get("content", content)
page_id=result.get("page_id"), final_parent_page_id = result.params.get("parent_page_id")
page_title=result.get("title", final_title), final_connector_id = result.params.get("connector_id", connector_id)
page_url=result.get("url"),
content=final_content, if not final_title or not final_title.strip():
connector_id=actual_connector_id, logger.error("Title is empty or contains only whitespace")
search_space_id=search_space_id, return {
user_id=user_id, "status": "error",
) "message": "Page title cannot be empty. Please provide a valid title.",
if kb_result["status"] == "success": }
kb_message_suffix = (
" Your knowledge base has also been updated." logger.info(
f"Creating Notion page with final params: title='{final_title}'"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
) )
else: )
connector = result.scalars().first()
if not connector:
logger.warning(
f"No Notion connector found for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Notion connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
logger.info(f"Validated Notion connector: id={actual_connector_id}")
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.create_page(
title=final_title,
content=final_content,
parent_page_id=final_parent_page_id,
)
logger.info(
f"create_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success":
kb_message_suffix = ""
try:
from app.services.notion import NotionKBSyncService
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=result.get("page_id"),
page_title=result.get("title", final_title),
page_url=result.get("url"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync." kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
result["message"] = result.get("message", "") + kb_message_suffix result["message"] = result.get("message", "") + kb_message_suffix
return result return result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
""" """
Factory function to create the delete_notion_page tool. Factory function to create the delete_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for finding the correct Notion connector user_id: User ID for finding the correct Notion connector
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
Returns: Returns:
Configured delete_notion_page tool Configured delete_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def delete_notion_page( async def delete_notion_page(
@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
} }
try: try:
# Get page context (page_id, account, title) from indexed data async with async_session_maker() as db_session:
metadata_service = NotionToolMetadataService(db_session) # Get page context (page_id, account, title) from indexed data
context = await metadata_service.get_delete_context( metadata_service = NotionToolMetadataService(db_session)
search_space_id, user_id, page_title context = await metadata_service.get_delete_context(
) search_space_id, user_id, page_title
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
) )
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id") if "error" in context:
connector_id_from_context = account.get("id") error_msg = context["error"]
document_id = context.get("document_id") # Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.info( logger.warning(f"Page not found: {error_msg}")
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})" return {
) "status": "not_found",
"message": error_msg,
result = request_approval( }
action_type="notion_page_deletion",
tool_name="delete_notion_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Notion page deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
logger.info(
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Validate the connector
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
# Create connector instance
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
# Delete the page from Notion
result = await notion_connector.delete_page(page_id=final_page_id)
logger.info(
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
)
# If deletion was successful and user wants to delete from KB
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from sqlalchemy.future import select
from app.db import Document
# Get the document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else: else:
logger.warning(f"Document {document_id} not found in KB") logger.error(f"Failed to fetch delete context: {error_msg}")
except Exception as e: return {
logger.error(f"Failed to delete document from KB: {e}") "status": "error",
await db_session.rollback() "message": error_msg,
result["warning"] = ( }
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
)
# Update result with KB deletion status account = context.get("account", {})
if result.get("status") == "success": if account.get("auth_expired"):
result["deleted_from_kb"] = deleted_from_kb logger.warning(
if deleted_from_kb: "Notion account %s has expired authentication",
result["message"] = ( account.get("id"),
f"{result.get('message', '')} (also removed from knowledge base)"
) )
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
return result page_id = context.get("page_id")
connector_id_from_context = account.get("id")
document_id = context.get("document_id")
logger.info(
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="notion_page_deletion",
tool_name="delete_notion_page",
params={
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Notion page deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
logger.info(
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Validate the connector
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
# Create connector instance
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
# Delete the page from Notion
result = await notion_connector.delete_page(page_id=final_page_id)
logger.info(
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
)
# If deletion was successful and user wants to delete from KB
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from sqlalchemy.future import select
from app.db import Document
# Get the document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
)
# Update result with KB deletion status
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} (also removed from knowledge base)"
)
return result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,8 +21,14 @@ def create_update_notion_page_tool(
""" """
Factory function to create the update_notion_page tool. Factory function to create the update_notion_page tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache (see
``create_create_notion_page_tool`` for the full rationale).
Args: Args:
db_session: Database session for accessing Notion connector db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
search_space_id: Search space ID to find the Notion connector search_space_id: Search space ID to find the Notion connector
user_id: User ID for fetching user-specific context user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known) connector_id: Optional specific connector ID (if known)
@ -29,6 +36,7 @@ def create_update_notion_page_tool(
Returns: Returns:
Configured update_notion_page tool Configured update_notion_page tool
""" """
del db_session # per-call session — see docstring
@tool @tool
async def update_notion_page( async def update_notion_page(
@ -71,7 +79,7 @@ def create_update_notion_page_tool(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
logger.error( logger.error(
"Notion tool not properly configured - missing required parameters" "Notion tool not properly configured - missing required parameters"
) )
@ -88,152 +96,155 @@ def create_update_notion_page_tool(
} }
try: try:
metadata_service = NotionToolMetadataService(db_session) async with async_session_maker() as db_session:
context = await metadata_service.get_update_context( metadata_service = NotionToolMetadataService(db_session)
search_space_id, user_id, page_title context = await metadata_service.get_update_context(
) search_space_id, user_id, page_title
if "error" in context:
error_msg = context["error"]
# Check if it's a "not found" error (softer handling for LLM)
if "not found" in error_msg.lower():
logger.warning(f"Page not found: {error_msg}")
return {
"status": "not_found",
"message": error_msg,
}
else:
logger.error(f"Failed to fetch update context: {error_msg}")
return {
"status": "error",
"message": error_msg,
}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id")
document_id = context.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
logger.info(
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
)
result = request_approval(
action_type="notion_page_update",
tool_name="update_notion_page",
params={
"page_id": page_id,
"content": content,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Notion page update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_content = result.params.get("content", content)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
logger.info(
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.update_page(
page_id=final_page_id,
content=final_content,
)
logger.info(
f"update_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success" and document_id is not None:
from app.services.notion import NotionKBSyncService
logger.info(f"Updating knowledge base for document {document_id}...")
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
appended_content=final_content,
user_id=user_id,
search_space_id=search_space_id,
appended_block_ids=result.get("appended_block_ids"),
) )
if kb_result["status"] == "success": if "error" in context:
result["message"] = ( error_msg = context["error"]
f"{result['message']}. Your knowledge base has also been updated." # Check if it's a "not found" error (softer handling for LLM)
) if "not found" in error_msg.lower():
logger.info( logger.warning(f"Page not found: {error_msg}")
f"Knowledge base successfully updated for page {final_page_id}" return {
) "status": "not_found",
elif kb_result["status"] == "not_indexed": "message": error_msg,
result["message"] = ( }
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync." else:
) logger.error(f"Failed to fetch update context: {error_msg}")
else: return {
result["message"] = ( "status": "error",
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync." "message": error_msg,
) }
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning( logger.warning(
f"KB update failed for page {final_page_id}: {kb_result['message']}" "Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id")
document_id = context.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
logger.info(
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
)
result = request_approval(
action_type="notion_page_update",
tool_name="update_notion_page",
params={
"page_id": page_id,
"content": content,
"connector_id": connector_id_from_context,
},
context=context,
)
if result.rejected:
logger.info("Notion page update rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_page_id = result.params.get("page_id", page_id)
final_content = result.params.get("content", content)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
logger.info(
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
}
actual_connector_id = connector.id
logger.info(f"Validated Notion connector: id={actual_connector_id}")
else:
logger.error("No connector found for this page")
return {
"status": "error",
"message": "No connector found for this page.",
}
notion_connector = NotionHistoryConnector(
session=db_session,
connector_id=actual_connector_id,
)
result = await notion_connector.update_page(
page_id=final_page_id,
content=final_content,
)
logger.info(
f"update_page result: {result.get('status')} - {result.get('message', '')}"
)
if result.get("status") == "success" and document_id is not None:
from app.services.notion import NotionKBSyncService
logger.info(
f"Updating knowledge base for document {document_id}..."
)
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
appended_content=final_content,
user_id=user_id,
search_space_id=search_space_id,
appended_block_ids=result.get("appended_block_ids"),
) )
return result if kb_result["status"] == "success":
result["message"] = (
f"{result['message']}. Your knowledge base has also been updated."
)
logger.info(
f"Knowledge base successfully updated for page {final_page_id}"
)
elif kb_result["status"] == "not_indexed":
result["message"] = (
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
)
else:
result["message"] = (
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
)
logger.warning(
f"KB update failed for page {final_page_id}: {kb_result['message']}"
)
return result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -10,7 +10,7 @@ from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the create_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def create_onedrive_file( async def create_onedrive_file(
name: str, name: str,
@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
""" """
logger.info(f"create_onedrive_file called: name='{name}'") logger.info(f"create_onedrive_file called: name='{name}'")
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "OneDrive tool not properly configured.", "message": "OneDrive tool not properly configured.",
} }
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
connectors = result.scalars().all()
if not connectors:
return {
"status": "error",
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected OneDrive accounts need re-authentication.",
"connector_type": "onedrive",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = OneDriveClient(session=db_session, connector_id=cid)
items, err = await client.list_children("root")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{"folder_id": item["id"], "name": item["name"]}
for item in items
if item.get("folder") is not None
and item.get("id")
and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s", cid, exc_info=True
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
}
result = request_approval(
action_type="onedrive_file_creation",
tool_name="create_onedrive_file",
params={
"name": name,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_name = result.params.get("name", name)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_docx_extension(final_name)
if final_connector_id is not None:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR, == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
) )
) )
connector = result.scalars().first() connectors = result.scalars().all()
else:
connector = connectors[0]
if not connector: if not connectors:
return { return {
"status": "error", "status": "error",
"message": "Selected OneDrive connector is invalid.", "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
}
accounts = []
for c in connectors:
cfg = c.config or {}
accounts.append(
{
"id": c.id,
"name": c.name,
"user_email": cfg.get("user_email"),
"auth_expired": cfg.get("auth_expired", False),
}
)
if all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected OneDrive accounts need re-authentication.",
"connector_type": "onedrive",
}
parent_folders: dict[int, list[dict[str, str]]] = {}
for acc in accounts:
cid = acc["id"]
if acc.get("auth_expired"):
parent_folders[cid] = []
continue
try:
client = OneDriveClient(session=db_session, connector_id=cid)
items, err = await client.list_children("root")
if err:
logger.warning(
"Failed to list folders for connector %s: %s", cid, err
)
parent_folders[cid] = []
else:
parent_folders[cid] = [
{"folder_id": item["id"], "name": item["name"]}
for item in items
if item.get("folder") is not None
and item.get("id")
and item.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s",
cid,
exc_info=True,
)
parent_folders[cid] = []
context: dict[str, Any] = {
"accounts": accounts,
"parent_folders": parent_folders,
} }
docx_bytes = _markdown_to_docx(final_content or "") result = request_approval(
action_type="onedrive_file_creation",
client = OneDriveClient(session=db_session, connector_id=connector.id) tool_name="create_onedrive_file",
created = await client.create_file( params={
name=final_name, "name": name,
parent_id=final_parent_folder_id, "content": content,
content=docx_bytes, "connector_id": None,
mime_type=DOCX_MIME, "parent_folder_id": None,
) },
context=context,
logger.info(
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.onedrive import OneDriveKBSyncService
kb_service = OneDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=DOCX_MIME,
web_url=created.get("webUrl"),
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
) )
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return { if result.rejected:
"status": "success", return {
"file_id": created.get("id"), "status": "rejected",
"name": created.get("name"), "message": "User declined. Do not retry or suggest alternatives.",
"web_url": created.get("webUrl"), }
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
} final_name = result.params.get("name", name)
final_content = result.params.get("content", content)
final_connector_id = result.params.get("connector_id")
final_parent_folder_id = result.params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
final_name = _ensure_docx_extension(final_name)
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
else:
connector = connectors[0]
if not connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid.",
}
docx_bytes = _markdown_to_docx(final_content or "")
client = OneDriveClient(session=db_session, connector_id=connector.id)
created = await client.create_file(
name=final_name,
parent_id=final_parent_folder_id,
content=docx_bytes,
mime_type=DOCX_MIME,
)
logger.info(
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
)
kb_message_suffix = ""
try:
from app.services.onedrive import OneDriveKBSyncService
kb_service = OneDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=DOCX_MIME,
web_url=created.get("webUrl"),
content=final_content,
connector_id=connector.id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_url": created.get("webUrl"),
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -13,6 +13,7 @@ from app.db import (
DocumentType, DocumentType,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
async_session_maker,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the delete_onedrive_file tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured delete_onedrive_file tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def delete_onedrive_file( async def delete_onedrive_file(
file_name: str, file_name: str,
@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
) )
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return { return {
"status": "error", "status": "error",
"message": "OneDrive tool not properly configured.", "message": "OneDrive tool not properly configured.",
} }
try: try:
doc_result = await db_session.execute( async with async_session_maker() as db_session:
select(Document)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
document = doc_result.scalars().first()
if not document:
doc_result = await db_session.execute( doc_result = await db_session.execute(
select(Document) select(Document)
.join( .join(
@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
and_( and_(
Document.search_space_id == search_space_id, Document.search_space_id == search_space_id,
Document.document_type == DocumentType.ONEDRIVE_FILE, Document.document_type == DocumentType.ONEDRIVE_FILE,
func.lower( func.lower(Document.title) == func.lower(file_name),
cast(
Document.document_metadata["onedrive_file_name"],
String,
)
)
== func.lower(file_name),
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
) )
) )
@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
) )
document = doc_result.scalars().first() document = doc_result.scalars().first()
if not document: if not document:
return { doc_result = await db_session.execute(
"status": "not_found", select(Document)
"message": ( .join(
f"File '{file_name}' not found in your indexed OneDrive files. " SearchSourceConnector,
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, " Document.connector_id == SearchSourceConnector.id,
"or (3) the file name is different." )
), .filter(
} and_(
Document.search_space_id == search_space_id,
if not document.connector_id: Document.document_type == DocumentType.ONEDRIVE_FILE,
return { func.lower(
"status": "error", cast(
"message": "Document has no associated connector.", Document.document_metadata[
} "onedrive_file_name"
],
meta = document.document_metadata or {} String,
file_id = meta.get("onedrive_file_id") )
document_id = document.id )
== func.lower(file_name),
if not file_id: SearchSourceConnector.user_id == user_id,
return { )
"status": "error", )
"message": "File ID is missing. Please re-index the file.", .order_by(Document.updated_at.desc().nullslast())
} .limit(1)
conn_result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
) )
) document = doc_result.scalars().first()
)
connector = conn_result.scalars().first()
if not connector:
return {
"status": "error",
"message": "OneDrive connector not found or access denied.",
}
cfg = connector.config or {} if not document:
if cfg.get("auth_expired"): return {
return { "status": "not_found",
"status": "auth_error", "message": (
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.", f"File '{file_name}' not found in your indexed OneDrive files. "
"connector_type": "onedrive", "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
} "or (3) the file name is different."
),
}
context = { if not document.connector_id:
"file": { return {
"file_id": file_id, "status": "error",
"name": file_name, "message": "Document has no associated connector.",
"document_id": document_id, }
"web_url": meta.get("web_url"),
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
result = request_approval( meta = document.document_metadata or {}
action_type="onedrive_file_trash", file_id = meta.get("onedrive_file_id")
tool_name="delete_onedrive_file", document_id = document.id
params={
"file_id": file_id,
"connector_id": connector.id,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected: if not file_id:
return { return {
"status": "rejected", "status": "error",
"message": "User declined. Do not retry or suggest alternatives.", "message": "File ID is missing. Please re-index the file.",
} }
final_file_id = result.params.get("file_id", file_id) conn_result = await db_session.execute(
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
and_( and_(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
) )
) )
) )
validated_connector = result.scalars().first() connector = conn_result.scalars().first()
if not validated_connector: if not connector:
return { return {
"status": "error", "status": "error",
"message": "Selected OneDrive connector is invalid or has been disconnected.", "message": "OneDrive connector not found or access denied.",
} }
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info( cfg = connector.config or {}
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" if cfg.get("auth_expired"):
) return {
"status": "auth_error",
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "onedrive",
}
client = OneDriveClient( context = {
session=db_session, connector_id=actual_connector_id "file": {
) "file_id": file_id,
await client.trash_file(final_file_id) "name": file_name,
"document_id": document_id,
"web_url": meta.get("web_url"),
},
"account": {
"id": connector.id,
"name": connector.name,
"user_email": cfg.get("user_email"),
},
}
logger.info( result = request_approval(
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" action_type="onedrive_file_trash",
) tool_name="delete_onedrive_file",
params={
trash_result: dict[str, Any] = { "file_id": file_id,
"status": "success", "connector_id": connector.id,
"file_id": final_file_id, "delete_from_kb": delete_from_kb,
"message": f"Successfully moved '{file_name}' to the recycle bin.", },
} context=context,
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
) )
return trash_result if result.rejected:
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_file_id = result.params.get("file_id", file_id)
final_connector_id = result.params.get("connector_id", connector.id)
final_delete_from_kb = result.params.get(
"delete_from_kb", delete_from_kb
)
if final_connector_id != connector.id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id
== search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
)
)
)
validated_connector = result.scalars().first()
if not validated_connector:
return {
"status": "error",
"message": "Selected OneDrive connector is invalid or has been disconnected.",
}
actual_connector_id = validated_connector.id
else:
actual_connector_id = connector.id
logger.info(
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
)
client = OneDriveClient(
session=db_session, connector_id=actual_connector_id
)
await client.trash_file(final_file_id)
logger.info(
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file_name}' to the recycle bin.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
doc = doc_result.scalars().first()
if doc:
await db_session.delete(doc)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -824,13 +824,22 @@ async def build_tools_async(
"""Async version of build_tools that also loads MCP tools from database. """Async version of build_tools that also loads MCP tools from database.
Design Note: Design Note:
This function exists because MCP tools require database queries to load user configs, This function exists because MCP tools require database queries to load
while built-in tools are created synchronously from static code. user configs, while built-in tools are created synchronously from static
code.
Alternative: We could make build_tools() itself async and always query the database, Alternative: We could make build_tools() itself async and always query
but that would force async everywhere even when only using built-in tools. The current the database, but that would force async everywhere even when only using
design keeps the simple case (static tools only) synchronous while supporting dynamic built-in tools. The current design keeps the simple case (static tools
database-loaded tools through this async wrapper. only) synchronous while supporting dynamic database-loaded tools through
this async wrapper.
Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
the event loop) are kicked off concurrently. Cold-path savings are
bounded by the slower of the two typically MCP at ~200ms-1.7s
so the parallelization recovers the ~50-200ms previously spent
serially on built-in construction.
Args: Args:
dependencies: Dict containing all possible dependencies dependencies: Dict containing all possible dependencies
@ -843,33 +852,70 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools. List of configured tool instances ready for the agent, including MCP tools.
""" """
import asyncio
import time import time
_perf_log = logging.getLogger("surfsense.perf") _perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG) _perf_log.setLevel(logging.DEBUG)
can_load_mcp = (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
)
# Built-in tool construction is synchronous + CPU-only. Off-loop it so
# MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
# function over its inputs — safe to thread-shift.
_t0 = time.perf_counter() _t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) builtin_task = asyncio.create_task(
asyncio.to_thread(
build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
)
)
mcp_task: asyncio.Task | None = None
if can_load_mcp:
mcp_task = asyncio.create_task(
load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
)
# Surface failures from each task independently so a flaky MCP
# endpoint never poisons built-in tool registration. ``return_exceptions``
# gives us per-task exceptions instead of dropping the second result
# when the first raises.
if mcp_task is not None:
builtin_result, mcp_result = await asyncio.gather(
builtin_task, mcp_task, return_exceptions=True
)
else:
builtin_result = await builtin_task
mcp_result = None
if isinstance(builtin_result, BaseException):
raise builtin_result # built-in registration failure is non-recoverable
tools: list[BaseTool] = builtin_result
_perf_log.info( _perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)", "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
len(tools), len(tools),
) )
# Load MCP tools if requested and dependencies are available if mcp_task is not None:
if ( if isinstance(mcp_result, BaseException):
include_mcp_tools # ``return_exceptions=True`` captures the exception out-of-band,
and "db_session" in dependencies # so ``sys.exc_info()`` is empty here. Pass the captured
and "search_space_id" in dependencies # exception via ``exc_info=`` to get a real traceback.
): logging.error(
try: "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
) )
else:
mcp_tools = mcp_result or []
_perf_log.info( _perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)", "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
time.perf_counter() - _t0, time.perf_counter() - _t0,
len(mcp_tools), len(mcp_tools),
) )
@ -879,8 +925,6 @@ async def build_tools_async(
len(mcp_tools), len(mcp_tools),
[t.name for t in mcp_tools], [t.name for t in mcp_tools],
) )
except Exception as e:
logging.exception("Failed to load MCP tools: %s", e)
logging.info( logging.info(
"Total tools for agent: %d%s", "Total tools for agent: %d%s",

View file

@ -15,7 +15,7 @@ from langchain_core.tools import tool
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text from app.utils.document_converters import embed_text
@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
""" """
Factory function to create the search_surfsense_docs tool. Factory function to create the search_surfsense_docs tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args: Args:
db_session: Database session for executing queries db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns: Returns:
A configured tool function for searching Surfsense documentation A configured tool function for searching Surfsense documentation
""" """
del db_session # per-call session — see docstring
@tool @tool
async def search_surfsense_docs(query: str, top_k: int = 10) -> str: async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
Returns: Returns:
Relevant documentation content formatted with chunk IDs for citations Relevant documentation content formatted with chunk IDs for citations
""" """
return await search_surfsense_docs_async( async with async_session_maker() as db_session:
query=query, return await search_surfsense_docs_async(
db_session=db_session, query=query,
top_k=top_k, db_session=db_session,
) top_k=top_k,
)
return search_surfsense_docs return search_surfsense_docs

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the list_teams_channels tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured list_teams_channels tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def list_teams_channels() -> dict[str, Any]: async def list_teams_channels() -> dict[str, Any]:
"""List all Microsoft Teams and their channels the user has access to. """List all Microsoft Teams and their channels the user has access to.
@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
Dictionary with status and a list of teams, each containing Dictionary with status and a list of teams, each containing
team_id, team_name, and a list of channels (id, name). team_id, team_name, and a list of channels (id, name).
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_teams_connector(
return {"status": "error", "message": "No Teams connector found."} db_session, search_space_id, user_id
token = await get_access_token(db_session, connector)
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=20.0) as client:
teams_resp = await client.get(
f"{GRAPH_API}/me/joinedTeams", headers=headers
) )
if not connector:
return {"status": "error", "message": "No Teams connector found."}
if teams_resp.status_code == 401: token = await get_access_token(db_session, connector)
return { headers = {"Authorization": f"Bearer {token}"}
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if teams_resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {teams_resp.status_code}",
}
teams_data = teams_resp.json().get("value", []) async with httpx.AsyncClient(timeout=20.0) as client:
result_teams = [] teams_resp = await client.get(
f"{GRAPH_API}/me/joinedTeams", headers=headers
async with httpx.AsyncClient(timeout=20.0) as client:
for team in teams_data:
team_id = team["id"]
ch_resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels",
headers=headers,
)
channels = []
if ch_resp.status_code == 200:
channels = [
{"id": ch["id"], "name": ch.get("displayName", "")}
for ch in ch_resp.json().get("value", [])
]
result_teams.append(
{
"team_id": team_id,
"team_name": team.get("displayName", ""),
"channels": channels,
}
) )
return { if teams_resp.status_code == 401:
"status": "success", return {
"teams": result_teams, "status": "auth_error",
"total_teams": len(result_teams), "message": "Teams token expired. Please re-authenticate.",
} "connector_type": "teams",
}
if teams_resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {teams_resp.status_code}",
}
teams_data = teams_resp.json().get("value", [])
result_teams = []
async with httpx.AsyncClient(timeout=20.0) as client:
for team in teams_data:
team_id = team["id"]
ch_resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels",
headers=headers,
)
channels = []
if ch_resp.status_code == 200:
channels = [
{"id": ch["id"], "name": ch.get("displayName", "")}
for ch in ch_resp.json().get("value", [])
]
result_teams.append(
{
"team_id": team_id,
"team_name": team.get("displayName", ""),
"channels": channels,
}
)
return {
"status": "success",
"teams": result_teams,
"total_teams": len(result_teams),
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -5,6 +5,8 @@ import httpx
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the read_teams_messages tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured read_teams_messages tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def read_teams_messages( async def read_teams_messages(
team_id: str, team_id: str,
@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
Dictionary with status and a list of messages including Dictionary with status and a list of messages including
id, sender, content, timestamp. id, sender, content, timestamp.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
limit = min(limit, 50) limit = min(limit, 50)
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_teams_connector(
return {"status": "error", "message": "No Teams connector found."} db_session, search_space_id, user_id
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.get(
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
headers={"Authorization": f"Bearer {token}"},
params={"$top": limit},
) )
if not connector:
return {"status": "error", "message": "No Teams connector found."}
if resp.status_code == 401: token = await get_access_token(db_session, connector)
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "error",
"message": "Insufficient permissions to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}",
}
raw_msgs = resp.json().get("value", []) async with httpx.AsyncClient(timeout=20.0) as client:
messages = [] resp = await client.get(
for m in raw_msgs: f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
sender = m.get("from", {}) headers={"Authorization": f"Bearer {token}"},
user_info = sender.get("user", {}) if sender else {} params={"$top": limit},
body = m.get("body", {}) )
messages.append(
{ if resp.status_code == 401:
"id": m.get("id"), return {
"sender": user_info.get("displayName", "Unknown"), "status": "auth_error",
"content": body.get("content", ""), "message": "Teams token expired. Please re-authenticate.",
"content_type": body.get("contentType", "text"), "connector_type": "teams",
"timestamp": m.get("createdDateTime", ""), }
if resp.status_code == 403:
return {
"status": "error",
"message": "Insufficient permissions to read this channel.",
}
if resp.status_code != 200:
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}",
} }
)
return { raw_msgs = resp.json().get("value", [])
"status": "success", messages = []
"team_id": team_id, for m in raw_msgs:
"channel_id": channel_id, sender = m.get("from", {})
"messages": messages, user_info = sender.get("user", {}) if sender else {}
"total": len(messages), body = m.get("body", {})
} messages.append(
{
"id": m.get("id"),
"sender": user_info.get("displayName", "Unknown"),
"content": body.get("content", ""),
"content_type": body.get("contentType", "text"),
"timestamp": m.get("createdDateTime", ""),
}
)
return {
"status": "success",
"team_id": team_id,
"channel_id": channel_id,
"messages": messages,
"total": len(messages),
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -6,6 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector from ._auth import GRAPH_API, get_access_token, get_teams_connector
@ -17,6 +18,23 @@ def create_send_teams_message_tool(
search_space_id: int | None = None, search_space_id: int | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
"""
Factory function to create the send_teams_message tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured send_teams_message tool
"""
del db_session # per-call session — see docstring
@tool @tool
async def send_teams_message( async def send_teams_message(
team_id: str, team_id: str,
@ -39,70 +57,73 @@ def create_send_teams_message_tool(
IMPORTANT: IMPORTANT:
- If status is "rejected", the user explicitly declined. Do NOT retry. - If status is "rejected", the user explicitly declined. Do NOT retry.
""" """
if db_session is None or search_space_id is None or user_id is None: if search_space_id is None or user_id is None:
return {"status": "error", "message": "Teams tool not properly configured."} return {"status": "error", "message": "Teams tool not properly configured."}
try: try:
connector = await get_teams_connector(db_session, search_space_id, user_id) async with async_session_maker() as db_session:
if not connector: connector = await get_teams_connector(
return {"status": "error", "message": "No Teams connector found."} db_session, search_space_id, user_id
)
if not connector:
return {"status": "error", "message": "No Teams connector found."}
result = request_approval( result = request_approval(
action_type="teams_send_message", action_type="teams_send_message",
tool_name="send_teams_message", tool_name="send_teams_message",
params={ params={
"team_id": team_id, "team_id": team_id,
"channel_id": channel_id, "channel_id": channel_id,
"content": content, "content": content,
},
context={"connector_id": connector.id},
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. Message was not sent.",
}
final_content = result.params.get("content", content)
final_team = result.params.get("team_id", team_id)
final_channel = result.params.get("channel_id", channel_id)
token = await get_access_token(db_session, connector)
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}, },
json={"body": {"content": final_content}}, context={"connector_id": connector.id},
) )
if resp.status_code == 401: if result.rejected:
return { return {
"status": "auth_error", "status": "rejected",
"message": "Teams token expired. Please re-authenticate.", "message": "User declined. Message was not sent.",
"connector_type": "teams", }
}
if resp.status_code == 403:
return {
"status": "insufficient_permissions",
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}{resp.text[:200]}",
}
msg_data = resp.json() final_content = result.params.get("content", content)
return { final_team = result.params.get("team_id", team_id)
"status": "success", final_channel = result.params.get("channel_id", channel_id)
"message_id": msg_data.get("id"),
"message": "Message sent to Teams channel.", token = await get_access_token(db_session, connector)
}
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json={"body": {"content": final_content}},
)
if resp.status_code == 401:
return {
"status": "auth_error",
"message": "Teams token expired. Please re-authenticate.",
"connector_type": "teams",
}
if resp.status_code == 403:
return {
"status": "insufficient_permissions",
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
}
if resp.status_code not in (200, 201):
return {
"status": "error",
"message": f"Graph API error: {resp.status_code}{resp.text[:200]}",
}
msg_data = resp.json()
return {
"status": "success",
"message_id": msg_data.get("id"),
"message": "Message sent to Teams channel.",
}
except Exception as e: except Exception as e:
from langgraph.errors import GraphInterrupt from langgraph.errors import GraphInterrupt

View file

@ -26,7 +26,7 @@ from langchain_core.tools import tool
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSpace, User from app.db import SearchSpace, User, async_session_maker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -295,6 +295,25 @@ def create_update_memory_tool(
db_session: AsyncSession, db_session: AsyncSession,
llm: Any | None = None, llm: Any | None = None,
): ):
"""Factory function to create the user-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
user_id: ID of the user whose memory document is being updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the user-memory scope.
"""
del db_session # per-call session — see docstring
uid = UUID(user_id) if isinstance(user_id, str) else user_id uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool @tool
@ -311,26 +330,26 @@ def create_update_memory_tool(
updated_memory: The FULL updated markdown document (not a diff). updated_memory: The FULL updated markdown document (not a diff).
""" """
try: try:
result = await db_session.execute(select(User).where(User.id == uid)) async with async_session_maker() as db_session:
user = result.scalars().first() result = await db_session.execute(select(User).where(User.id == uid))
if not user: user = result.scalars().first()
return {"status": "error", "message": "User not found."} if not user:
return {"status": "error", "message": "User not found."}
old_memory = user.memory_md old_memory = user.memory_md
return await _save_memory( return await _save_memory(
updated_memory=updated_memory, updated_memory=updated_memory,
old_memory=old_memory, old_memory=old_memory,
llm=llm, llm=llm,
apply_fn=lambda content: setattr(user, "memory_md", content), apply_fn=lambda content: setattr(user, "memory_md", content),
commit_fn=db_session.commit, commit_fn=db_session.commit,
rollback_fn=db_session.rollback, rollback_fn=db_session.rollback,
label="memory", label="memory",
scope="user", scope="user",
) )
except Exception as e: except Exception as e:
logger.exception("Failed to update user memory: %s", e) logger.exception("Failed to update user memory: %s", e)
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"message": f"Failed to update memory: {e}", "message": f"Failed to update memory: {e}",
@ -344,6 +363,27 @@ def create_update_team_memory_tool(
db_session: AsyncSession, db_session: AsyncSession,
llm: Any | None = None, llm: Any | None = None,
): ):
"""Factory function to create the team-memory update tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
The session's bound ``commit``/``rollback`` methods are captured at
call time, after ``async with`` has bound ``db_session`` locally.
Args:
search_space_id: ID of the search space whose team memory is being
updated.
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
llm: Optional LLM for the forced-rewrite path.
Returns:
Configured update_memory tool for the team-memory scope.
"""
del db_session # per-call session — see docstring
@tool @tool
async def update_memory(updated_memory: str) -> dict[str, Any]: async def update_memory(updated_memory: str) -> dict[str, Any]:
"""Update the team's shared memory document for this search space. """Update the team's shared memory document for this search space.
@ -359,28 +399,30 @@ def create_update_team_memory_tool(
updated_memory: The FULL updated markdown document (not a diff). updated_memory: The FULL updated markdown document (not a diff).
""" """
try: try:
result = await db_session.execute( async with async_session_maker() as db_session:
select(SearchSpace).where(SearchSpace.id == search_space_id) result = await db_session.execute(
) select(SearchSpace).where(SearchSpace.id == search_space_id)
space = result.scalars().first() )
if not space: space = result.scalars().first()
return {"status": "error", "message": "Search space not found."} if not space:
return {"status": "error", "message": "Search space not found."}
old_memory = space.shared_memory_md old_memory = space.shared_memory_md
return await _save_memory( return await _save_memory(
updated_memory=updated_memory, updated_memory=updated_memory,
old_memory=old_memory, old_memory=old_memory,
llm=llm, llm=llm,
apply_fn=lambda content: setattr(space, "shared_memory_md", content), apply_fn=lambda content: setattr(
commit_fn=db_session.commit, space, "shared_memory_md", content
rollback_fn=db_session.rollback, ),
label="team memory", commit_fn=db_session.commit,
scope="team", rollback_fn=db_session.rollback,
) label="team memory",
scope="team",
)
except Exception as e: except Exception as e:
logger.exception("Failed to update team memory: %s", e) logger.exception("Failed to update team memory: %s", e)
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"message": f"Failed to update team memory: {e}", "message": f"Failed to update team memory: {e}",

View file

@ -31,6 +31,7 @@ from app.config import (
initialize_image_gen_router, initialize_image_gen_router,
initialize_llm_router, initialize_llm_router,
initialize_openrouter_integration, initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router, initialize_vision_llm_router,
) )
from app.db import User, create_db_and_tables, get_async_session from app.db import User, create_db_and_tables, get_async_session
@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
OpenRouterIntegrationService.get_instance().stop_background_refresh() OpenRouterIntegrationService.get_instance().stop_background_refresh()
async def _warm_agent_jit_caches() -> None:
"""Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
Why
----
A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
generation chain takes 1.5-2 seconds of pure CPU on first invocation
inside any Python process: the graph compiler builds reducers,
Pydantic v2 generates and JITs validator schemas, deepagents
eagerly compiles its general-purpose subagent, etc. Subsequent
compiles in the same process pay only ~50% of that cost (the lazy
JIT bits are cached in module-level dicts).
Doing one throwaway compile during ``lifespan`` startup pre-pays
that cost so the *first real request* doesn't. We do NOT prime
:mod:`agent_cache` because the cache key requires real
``thread_id`` / ``user_id`` / ``search_space_id`` / etc. the
throwaway agent is genuinely thrown away and immediately collected.
Safety
------
* No DB access. We construct a stub LLM (no real keys), pass an
empty tools list, and pass ``checkpointer=None`` so we never
touch Postgres.
* Bounded by ``asyncio.wait_for`` so a hang here can never block
worker startup. On any failure, we log + swallow the worst
case is the first real request pays the full cold cost (i.e.
pre-warmup behaviour).
"""
import time as _time
logger = logging.getLogger(__name__)
t0 = _time.perf_counter()
try:
from langchain.agents import create_agent
from langchain.agents.middleware import (
ModelCallLimitMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
)
from langchain_core.tools import tool
from app.agents.new_chat.context import SurfSenseContextSchema
# Minimal LLM stub. ``FakeListChatModel`` satisfies
# ``BaseChatModel`` without any network or auth — perfect for
# exercising the compile path without side effects.
stub_llm = FakeListChatModel(responses=["warmup-response"])
# Two trivial tools with arg + return schemas — exercises the
# Pydantic v2 schema JIT path. Without at least one tool the
# graph compile skips the tool-loop bytecode generation that
# accounts for ~30-50% of cold compile cost.
@tool
def _warmup_tool_a(query: str, limit: int = 5) -> str:
"""Warmup tool A — never actually invoked."""
return query[:limit]
@tool
def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
"""Warmup tool B — never actually invoked."""
return {"name": name, "value": value}
# A handful of common middleware so the compile pre-pays the
# ``AgentMiddleware`` resolver path. These instances never run
# because the throwaway agent is immediately collected.
# ``SubAgentMiddleware`` is the single heaviest line in cold
# ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
# compile its general-purpose subagent's full inner graph),
# so we include it here to make sure that compile path is JIT'd.
warmup_middleware: list = [
TodoListMiddleware(),
ModelCallLimitMiddleware(
thread_limit=120, run_limit=80, exit_behavior="end"
),
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
),
]
try:
from deepagents import SubAgentMiddleware
from deepagents.backends import StateBackend
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
gp_warmup_spec = { # type: ignore[var-annotated]
**GENERAL_PURPOSE_SUBAGENT,
"model": stub_llm,
"tools": [_warmup_tool_a],
"middleware": [TodoListMiddleware()],
}
warmup_middleware.append(
SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
)
except Exception:
# Deepagents missing/incompatible — middleware-only warmup
# still produces a useful (smaller) speedup.
logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
compiled = create_agent(
stub_llm,
tools=[_warmup_tool_a, _warmup_tool_b],
system_prompt="You are a warmup stub.",
middleware=warmup_middleware,
context_schema=SurfSenseContextSchema,
checkpointer=None,
)
# Touch the compiled graph's stream_channels / nodes so any
# remaining lazy schema work fires now instead of on first
# real invocation.
_ = list(getattr(compiled, "nodes", {}).keys())
del compiled
logger.info(
"[startup] Agent JIT warmup completed in %.3fs",
_time.perf_counter() - t0,
)
except Exception:
logger.warning(
"[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
"real request will pay the full compile cost)",
_time.perf_counter() - t0,
exc_info=True,
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected # Tune GC: lower gen-2 threshold so long-lived garbage is collected
@ -432,6 +562,7 @@ async def lifespan(app: FastAPI):
await setup_checkpointer_tables() await setup_checkpointer_tables()
initialize_openrouter_integration() initialize_openrouter_integration()
_start_openrouter_background_refresh() _start_openrouter_background_refresh()
initialize_pricing_registration()
initialize_llm_router() initialize_llm_router()
initialize_image_gen_router() initialize_image_gen_router()
initialize_vision_llm_router() initialize_vision_llm_router()
@ -443,6 +574,18 @@ async def lifespan(app: FastAPI):
"Docs will be indexed on the next restart." "Docs will be indexed on the next restart."
) )
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
# worker readiness. ``shield`` so Uvicorn cancelling startup
# doesn't leave half-warmed Pydantic schemas in an inconsistent
# state.
try:
await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
except (TimeoutError, Exception): # pragma: no cover - defensive
logging.getLogger(__name__).warning(
"[startup] Agent JIT warmup hit timeout/error — skipping; "
"first real request will pay the full compile cost."
)
log_system_snapshot("startup_complete") log_system_snapshot("startup_complete")
yield yield
@ -452,6 +595,23 @@ async def lifespan(app: FastAPI):
def registration_allowed(): def registration_allowed():
"""Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
Despite the name, this dependency does NOT only gate registration. When
REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
that could mint or refresh a session for an attacker:
* email/password ``POST /auth/register``
* email/password ``POST /auth/jwt/login``
* the Google OAuth router (``/auth/google/authorize`` and the shared
``/auth/google/callback`` handles both new signups and login for
existing users, so flipping this off locks both)
* the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
Use it as a temporary "freeze all new sessions" lever during incident
response. It is not a way to disable signup while keeping login working;
for that, override ``UserManager.oauth_callback`` instead.
"""
if not config.REGISTRATION_ENABLED: if not config.REGISTRATION_ENABLED:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled" status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
@ -596,32 +756,45 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers allow_headers=["*"], # Allows all headers
) )
app.include_router( # Password / email-based auth routers are only mounted when not running in
fastapi_users.get_auth_router(auth_backend), # Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
prefix="/auth/jwt", # POST /auth/register reachable, which is the bypass that allowed bots to
tags=["auth"], # create non-OAuth users in spite of AUTH_TYPE=GOOGLE.
dependencies=[Depends(rate_limit_login)], if config.AUTH_TYPE != "GOOGLE":
) app.include_router(
app.include_router( fastapi_users.get_auth_router(auth_backend),
fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth/jwt",
prefix="/auth", tags=["auth"],
tags=["auth"], dependencies=[
dependencies=[ Depends(rate_limit_login),
Depends(rate_limit_register), Depends(
Depends(registration_allowed), # blocks registration when disabled registration_allowed
], ), # honour REGISTRATION_ENABLED kill switch on login too
) ],
app.include_router( )
fastapi_users.get_reset_password_router(), app.include_router(
prefix="/auth", fastapi_users.get_register_router(UserRead, UserCreate),
tags=["auth"], prefix="/auth",
dependencies=[Depends(rate_limit_password_reset)], tags=["auth"],
) dependencies=[
app.include_router( Depends(rate_limit_register),
fastapi_users.get_verify_router(UserRead), Depends(registration_allowed),
prefix="/auth", ],
tags=["auth"], )
) app.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
dependencies=[Depends(rate_limit_password_reset)],
)
app.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
# /users/me (read/update profile) is needed in every auth mode, so it stays
# mounted unconditionally.
app.include_router( app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate), fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users", prefix="/users",
@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE":
), ),
prefix="/auth/google", prefix="/auth/google",
tags=["auth"], tags=["auth"],
dependencies=[ # REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
Depends(registration_allowed) # it blocks BOTH new OAuth signups AND login of existing OAuth users
], # blocks OAuth registration when disabled # (the fastapi-users OAuth router shares one callback for create+login,
# so this dependency closes both paths together).
dependencies=[Depends(registration_allowed)],
) )
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility # Add a redirect-based authorize endpoint for Firefox/Safari compatibility
# This endpoint performs a server-side redirect instead of returning JSON # This endpoint performs a server-side redirect instead of returning JSON
# which fixes cross-site cookie issues where browsers don't send cookies # which fixes cross-site cookie issues where browsers don't send cookies
# set via cross-origin fetch requests on subsequent redirects # set via cross-origin fetch requests on subsequent redirects.
@app.get("/auth/google/authorize-redirect", tags=["auth"]) # The registration_allowed dependency mirrors the OAuth router above so
# the kill switch fails fast here instead of bouncing users to Google
# only to 403 on the callback.
@app.get(
"/auth/google/authorize-redirect",
tags=["auth"],
dependencies=[Depends(registration_allowed)],
)
async def google_authorize_redirect( async def google_authorize_redirect(
request: Request, request: Request,
): ):

View file

@ -22,10 +22,12 @@ def init_worker(**kwargs):
initialize_image_gen_router, initialize_image_gen_router,
initialize_llm_router, initialize_llm_router,
initialize_openrouter_integration, initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router, initialize_vision_llm_router,
) )
initialize_openrouter_integration() initialize_openrouter_integration()
initialize_pricing_registration()
initialize_llm_router() initialize_llm_router()
initialize_image_gen_router() initialize_image_gen_router()
initialize_vision_llm_router() initialize_vision_llm_router()

View file

@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f) data = yaml.safe_load(f)
configs = data.get("global_llm_configs", []) configs = data.get("global_llm_configs", [])
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {} seen_slugs: dict[str, int] = {}
for cfg in configs: for cfg in configs:
cfg.setdefault("billing_tier", "free") cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False) cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False) cfg.setdefault("seo_enabled", False)
# Capability flag: explicit YAML override always wins. When the
# operator has not annotated the model, defer to LiteLLM's
# authoritative model map (`supports_vision`) which already
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
# vision-capable. Unknown / unmapped models default-allow so
# we don't lock the user out of a freshly added third-party
# entry; the streaming-task safety net (driven by
# `is_known_text_only_chat_model`) is the only place a False
# actually blocks a request.
if "supports_image_input" not in cfg:
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
if cfg.get("seo_enabled") and cfg.get("seo_slug"): if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"] slug = cfg["seo_slug"]
@ -63,6 +89,27 @@ def load_global_llm_configs():
else: else:
seen_slugs[slug] = cfg.get("id", 0) seen_slugs[slug] = cfg.get("id", 0)
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose provider == "OPENROUTER" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
for cfg in configs:
cfg["auto_pin_tier"] = "A"
static_q = static_score_yaml(cfg)
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose provider is OPENROUTER are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.
cfg["health_gated"] = False
except Exception as e:
print(f"Warning: Failed to score global LLM configs: {e}")
return configs return configs
except Exception as e: except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}") print(f"Warning: Failed to load global LLM configs: {e}")
@ -117,7 +164,11 @@ def load_global_image_gen_configs():
try: try:
with open(global_config_file, encoding="utf-8") as f: with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
return data.get("global_image_generation_configs", []) configs = data.get("global_image_generation_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e: except Exception as e:
print(f"Warning: Failed to load global image generation configs: {e}") print(f"Warning: Failed to load global image generation configs: {e}")
return [] return []
@ -132,7 +183,11 @@ def load_global_vision_llm_configs():
try: try:
with open(global_config_file, encoding="utf-8") as f: with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
return data.get("global_vision_llm_configs", []) configs = data.get("global_vision_llm_configs", []) or []
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e: except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}") print(f"Warning: Failed to load global vision LLM configs: {e}")
return [] return []
@ -194,6 +249,9 @@ def load_openrouter_integration_settings() -> dict | None:
""" """
Load OpenRouter integration settings from the YAML config. Load OpenRouter integration settings from the YAML config.
Emits startup warnings for deprecated keys (``billing_tier``,
``anonymous_enabled``) and seeds their replacements for back-compat.
Returns: Returns:
dict with settings if present and enabled, None otherwise dict with settings if present and enabled, None otherwise
""" """
@ -206,9 +264,40 @@ def load_openrouter_integration_settings() -> dict | None:
with open(global_config_file, encoding="utf-8") as f: with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
settings = data.get("openrouter_integration") settings = data.get("openrouter_integration")
if settings and settings.get("enabled"): if not settings or not settings.get("enabled"):
return settings return None
return None
if "billing_tier" in settings:
print(
"Warning: openrouter_integration.billing_tier is deprecated; "
"tier is now derived per model from OpenRouter data "
"(':free' suffix or zero pricing). Remove this key."
)
if "anonymous_enabled" in settings:
print(
"Warning: openrouter_integration.anonymous_enabled is "
"deprecated; use anonymous_enabled_paid and/or "
"anonymous_enabled_free instead. Both new flags have been "
"seeded from the legacy value for back-compat."
)
settings.setdefault(
"anonymous_enabled_paid", settings["anonymous_enabled"]
)
settings.setdefault(
"anonymous_enabled_free", settings["anonymous_enabled"]
)
# Image generation + vision LLM emission are opt-in (issue L).
# OpenRouter's catalogue contains hundreds of image / vision
# capable models; auto-injecting all of them into every
# deployment would explode the model selector and surprise
# operators upgrading from prior versions. Default to False so
# admins must explicitly turn them on.
settings.setdefault("image_generation_enabled", False)
settings.setdefault("vision_enabled", False)
return settings
except Exception as e: except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}") print(f"Warning: Failed to load OpenRouter integration settings: {e}")
return None return None
@ -217,9 +306,14 @@ def load_openrouter_integration_settings() -> dict | None:
def initialize_openrouter_integration(): def initialize_openrouter_integration():
""" """
If enabled, fetch all OpenRouter models and append them to If enabled, fetch all OpenRouter models and append them to
config.GLOBAL_LLM_CONFIGS as dynamic premium entries. config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
Should be called BEFORE initialize_llm_router() so the router is derived per-model from OpenRouter's API signals (``:free`` suffix or
correctly excludes premium models from Auto mode. zero pricing), so free OpenRouter models correctly skip premium quota.
Should be called BEFORE initialize_llm_router(). Dynamic entries are
tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
by title-gen / sub-agent flows) remains scoped to curated YAML configs,
while user-facing Auto-mode thread pinning still considers them.
""" """
settings = load_openrouter_integration_settings() settings = load_openrouter_integration_settings()
if not settings: if not settings:
@ -235,16 +329,70 @@ def initialize_openrouter_integration():
if new_configs: if new_configs:
config.GLOBAL_LLM_CONFIGS.extend(new_configs) config.GLOBAL_LLM_CONFIGS.extend(new_configs)
free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
premium_count = sum(
1 for c in new_configs if c.get("billing_tier") == "premium"
)
print( print(
f"Info: OpenRouter integration added {len(new_configs)} models " f"Info: OpenRouter integration added {len(new_configs)} models "
f"(billing_tier={settings.get('billing_tier', 'premium')})" f"(free={free_count}, premium={premium_count})"
) )
else: else:
print("Info: OpenRouter integration enabled but no models fetched") print("Info: OpenRouter integration enabled but no models fetched")
# Image generation + vision LLM emissions are opt-in (issue L).
# Both reuse the catalogue already cached by ``service.initialize``
# so we don't make additional network calls here.
if settings.get("image_generation_enabled"):
try:
image_configs = service.get_image_generation_configs()
if image_configs:
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
print(
f"Info: OpenRouter integration added {len(image_configs)} "
f"image-generation models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
if settings.get("vision_enabled"):
try:
vision_configs = service.get_vision_llm_configs()
if vision_configs:
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
print(
f"Info: OpenRouter integration added {len(vision_configs)} "
f"vision LLM models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
except Exception as e: except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}") print(f"Warning: Failed to initialize OpenRouter integration: {e}")
def initialize_pricing_registration():
"""
Teach LiteLLM the per-token cost of every deployment in
``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
from the OpenRouter catalogue + any operator-declared YAML pricing).
Must run AFTER ``initialize_openrouter_integration()`` so the
OpenRouter catalogue is populated and BEFORE the first LLM call so
``response_cost`` is available in ``TokenTrackingCallback``.
Failures are logged but never raised startup must not be blocked
by a missing pricing entry; the worst-case is the model debits 0.
"""
try:
from app.services.pricing_registration import (
register_pricing_from_global_configs,
)
register_pricing_from_global_configs()
except Exception as e:
print(f"Warning: Failed to register LiteLLM pricing: {e}")
def initialize_llm_router(): def initialize_llm_router():
""" """
Initialize the LLM Router service for Auto mode. Initialize the LLM Router service for Auto mode.
@ -389,14 +537,54 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
) )
# Premium token quota settings # Premium credit (micro-USD) quota settings.
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000")) #
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
# still honoured for one release as fall-back values — the prior
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
# to micros, so operators upgrading without changing their .env still
# get correct behaviour. A startup deprecation warning fires below if
# they're set.
PREMIUM_CREDIT_MICROS_LIMIT = int(
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
)
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")) STRIPE_CREDIT_MICROS_PER_UNIT = int(
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
)
STRIPE_TOKEN_BUYING_ENABLED = ( STRIPE_TOKEN_BUYING_ENABLED = (
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
) )
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
# config's ``quota_reserve_tokens`` and clamps the result to this value
# so a misconfigured "$1000/M" model can't lock the user's whole balance
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
# reserve_tokens ≈ $0.36) with headroom.
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
"PREMIUM_CREDIT_MICROS_LIMIT"
):
print(
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
"current Stripe price). The old key will be removed in a "
"future release."
)
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
"STRIPE_CREDIT_MICROS_PER_UNIT"
):
print(
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
"The old key will be removed in a future release."
)
# Anonymous / no-login mode settings # Anonymous / no-login mode settings
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
MULTI_AGENT_CHAT_ENABLED = ( MULTI_AGENT_CHAT_ENABLED = (
@ -412,6 +600,35 @@ class Config:
# Default quota reserve tokens when not specified per-model # Default quota reserve tokens when not specified per-model
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000")) QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
# ``POST /image-generations`` endpoint when the global config does not
# override it. $0.05 covers realistic worst-cases for current OpenAI /
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
)
# Per-podcast reservation (in micro-USD). One agent LLM call generating
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
# premium-model run. Tune via env.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
)
# Per-video-presentation reservation (in micro-USD). Fan-out of N
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
# plus refine retries; can produce many premium completions. $1.00
# covers worst-case. Tune via env.
#
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
# 1_000_000. The override path in ``billable_call`` bypasses the
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
# *actual* hold — raising it via env is fine but means a single video
# task can lock $1+ of credit.
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
)
# Abuse prevention: concurrent stream cap and CAPTCHA # Abuse prevention: concurrent stream cap and CAPTCHA
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2")) ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
ANON_CAPTCHA_REQUEST_THRESHOLD = int( ANON_CAPTCHA_REQUEST_THRESHOLD = int(

View file

@ -19,6 +19,24 @@
# Structure matches NewLLMConfig: # Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.) # - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled) # - Prompt configuration (system_instructions, citations_enabled)
#
# COST-BASED PREMIUM CREDITS:
# Each premium config bills the user's USD-credit balance based on the
# actual provider cost reported by LiteLLM. For models LiteLLM already
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
# or any model LiteLLM doesn't have in its built-in pricing table, declare
# per-token costs inline so they bill correctly:
#
# litellm_params:
# base_model: "my-custom-azure-deploy"
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
# input_cost_per_token: 0.000003
# output_cost_per_token: 0.000015
#
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
# API — no inline declaration needed. Models without resolvable pricing
# debit $0 from the user's balance and log a WARNING.
# Router Settings for Auto Mode # Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models # These settings control how the LiteLLM Router distributes requests across models
@ -245,31 +263,64 @@ global_llm_configs:
# ============================================================================= # =============================================================================
# When enabled, dynamically fetches ALL available models from the OpenRouter API # When enabled, dynamically fetches ALL available models from the OpenRouter API
# and injects them as global configs. This gives premium users access to any model # and injects them as global configs. This gives premium users access to any model
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota. # on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
# while free-tier OpenRouter models show up with a green Free badge and do NOT
# consume premium quota.
# Models are fetched at startup and refreshed periodically in the background. # Models are fetched at startup and refreshed periodically in the background.
# All calls go through LiteLLM with the openrouter/ prefix. # All calls go through LiteLLM with the openrouter/ prefix.
openrouter_integration: openrouter_integration:
enabled: false enabled: false
api_key: "sk-or-your-openrouter-api-key" api_key: "sk-or-your-openrouter-api-key"
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
billing_tier: "premium" # Tier is derived PER MODEL from OpenRouter's own API signals:
# anonymous_enabled: set true to also show OpenRouter models to no-login users # - id ends with ":free" -> billing_tier=free
anonymous_enabled: false # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
# - otherwise -> billing_tier=premium
# No global billing_tier knob is honored; any legacy value emits a startup warning.
# Anonymous access is split by tier so operators can expose only free
# models to no-login users without leaking paid inference.
anonymous_enabled_paid: false
anonymous_enabled_free: false
seo_enabled: false seo_enabled: false
# quota_reserve_tokens: tokens reserved per call for quota enforcement # quota_reserve_tokens: tokens reserved per call for quota enforcement
quota_reserve_tokens: 4000 quota_reserve_tokens: 4000
# id_offset: starting negative ID for dynamically generated configs. # id_offset: base negative ID for dynamically generated configs.
# Must not overlap with your static global_llm_configs IDs above. # Model IDs are derived deterministically via BLAKE2b so they survive
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
id_offset: -10000 id_offset: -10000
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
refresh_interval_hours: 24 refresh_interval_hours: 24
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits). # for per-deployment accounting when OR premium models participate in the
# These values only matter if you set billing_tier to "free" (adding them to Auto mode). # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
# For premium-only models they are cosmetic. Set conservatively or match your account tier. # real account limits live at https://openrouter.ai/settings/limits.
rpm: 200 rpm: 200
tpm: 1000000 tpm: 1000000
# Rate limits for FREE OpenRouter models. Informational only: free OR
# models are intentionally kept OUT of the LiteLLM Router pool, because
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
# 50-1000 daily requests across every ":free" model combined) —
# per-deployment router accounting can't represent a shared bucket
# correctly. Free OR models stay fully available in the model selector
# and for user-facing Auto thread pinning.
free_rpm: 20
free_tpm: 100000
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
# contains hundreds of image- and vision-capable models; turning these on
# injects them into the global Image-Generation / Vision-LLM model
# selectors alongside any static configs. Tier (free/premium) is derived
# per model the same way it is for chat (`:free` suffix or zero pricing).
# When a user picks a premium image/vision model the call debits the
# shared $5 USD-cost-based premium credit pool — so leaving these off
# avoids surprise quota burn on existing deployments. Default: false.
image_generation_enabled: false
vision_enabled: false
litellm_params: litellm_params:
max_tokens: 16384 max_tokens: 16384
system_instructions: "" system_instructions: ""

View file

@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False, default=False,
server_default="false", server_default="false",
) )
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
# config id. NULL means no pin; Auto will resolve on the next turn.
# Single-writer invariant: only app.services.auto_model_pin_service sets
# or clears this column (plus bulk clears when a search space's
# agent_llm_id changes). Unindexed: all reads are by primary key.
pinned_llm_config_id = Column(Integer, nullable=True)
# Relationships # Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads") search_space = relationship("SearchSpace", back_populates="new_chat_threads")
@ -669,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
__tablename__ = "new_chat_messages" __tablename__ = "new_chat_messages"
# Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
# Mirrors alembic migration 141. Lets the streaming agent and the
# legacy frontend appendMessage call coexist idempotently — the second
# writer trips the unique and recovers without creating a duplicate row.
# Partial so legacy NULL turn_id rows and clone/snapshot inserts in
# app/services/public_chat_service.py (which omit turn_id) are unaffected.
__table_args__ = (
Index(
"uq_new_chat_messages_thread_turn_role",
"thread_id",
"turn_id",
"role",
unique=True,
postgresql_where=text("turn_id IS NOT NULL"),
),
)
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False) role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
# Content stored as JSONB to support rich content (text, tool calls, etc.) # Content stored as JSONB to support rich content (text, tool calls, etc.)
content = Column(JSONB, nullable=False) content = Column(JSONB, nullable=False)
@ -722,9 +745,26 @@ class TokenUsage(BaseModel, TimestampMixin):
__tablename__ = "token_usage" __tablename__ = "token_usage"
# Partial unique index on (message_id) where message_id IS NOT NULL.
# Mirrors alembic migration 142. Lets the streaming agent's
# ``finalize_assistant_turn`` and the legacy frontend ``append_message``
# recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
# racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
# (indexing, image generation, podcasts) — which keep ``message_id`` NULL
# because there is no per-message anchor — are unaffected.
__table_args__ = (
Index(
"uq_token_usage_message_id",
"message_id",
unique=True,
postgresql_where=text("message_id IS NOT NULL"),
),
)
prompt_tokens = Column(Integer, nullable=False, default=0) prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0)
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
model_breakdown = Column(JSONB, nullable=True) model_breakdown = Column(JSONB, nullable=True)
call_details = Column(JSONB, nullable=True) call_details = Column(JSONB, nullable=True)
@ -1787,7 +1827,15 @@ class PagePurchase(Base, TimestampMixin):
class PremiumTokenPurchase(Base, TimestampMixin): class PremiumTokenPurchase(Base, TimestampMixin):
"""Tracks Stripe checkout sessions used to grant additional premium token credits.""" """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
Note: the table name is preserved (``premium_token_purchases``) for
operational continuity even though the unit is now USD micro-credits
instead of raw tokens. The ``credit_micros_granted`` column replaced
the legacy ``tokens_granted`` in migration 140; the stored values
were not transformed because the prior $1 = 1M tokens Stripe price
makes the unit conversion 1:1 numerically.
"""
__tablename__ = "premium_token_purchases" __tablename__ = "premium_token_purchases"
__allow_unmapped__ = True __allow_unmapped__ = True
@ -1804,7 +1852,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
) )
stripe_payment_intent_id = Column(String(255), nullable=True, index=True) stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
quantity = Column(Integer, nullable=False) quantity = Column(Integer, nullable=False)
tokens_granted = Column(BigInteger, nullable=False) credit_micros_granted = Column(BigInteger, nullable=False)
amount_total = Column(Integer, nullable=True) amount_total = Column(Integer, nullable=True)
currency = Column(String(10), nullable=True) currency = Column(String(10), nullable=True)
status = Column( status = Column(
@ -2103,16 +2151,16 @@ if config.AUTH_TYPE == "GOOGLE":
) )
pages_used = Column(Integer, nullable=False, default=0, server_default="0") pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column( premium_credit_micros_limit = Column(
BigInteger, BigInteger,
nullable=False, nullable=False,
default=config.PREMIUM_TOKEN_LIMIT, default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT), server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
) )
premium_tokens_used = Column( premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0" BigInteger, nullable=False, default=0, server_default="0"
) )
premium_tokens_reserved = Column( premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0" BigInteger, nullable=False, default=0, server_default="0"
) )
@ -2235,16 +2283,16 @@ else:
) )
pages_used = Column(Integer, nullable=False, default=0, server_default="0") pages_used = Column(Integer, nullable=False, default=0, server_default="0")
premium_tokens_limit = Column( premium_credit_micros_limit = Column(
BigInteger, BigInteger,
nullable=False, nullable=False,
default=config.PREMIUM_TOKEN_LIMIT, default=config.PREMIUM_CREDIT_MICROS_LIMIT,
server_default=str(config.PREMIUM_TOKEN_LIMIT), server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
) )
premium_tokens_used = Column( premium_credit_micros_used = Column(
BigInteger, nullable=False, default=0, server_default="0" BigInteger, nullable=False, default=0, server_default="0"
) )
premium_tokens_reserved = Column( premium_credit_micros_reserved = Column(
BigInteger, nullable=False, default=0, server_default="0" BigInteger, nullable=False, default=0, server_default="0"
) )

View file

@ -68,12 +68,25 @@ class EtlPipelineService:
etl_service="VISION_LLM", etl_service="VISION_LLM",
content_type="image", content_type="image",
) )
except Exception: except Exception as exc:
logging.warning( # Special-case quota exhaustion so we log a clearer message
"Vision LLM failed for %s, falling back to document parser", # — the vision LLM didn't "fail", the user just ran out of
request.filename, # premium credit. Falling through to the document parser
exc_info=True, # is a graceful degradation: OCR/Unstructured still
) # extracts text from the image without burning credit.
from app.services.billable_calls import QuotaInsufficientError
if isinstance(exc, QuotaInsufficientError):
logging.info(
"Vision LLM quota exhausted for %s; falling back to document parser",
request.filename,
)
else:
logging.warning(
"Vision LLM failed for %s, falling back to document parser",
request.filename,
exc_info=True,
)
else: else:
logging.info( logging.info(
"No vision LLM provided, falling back to document parser for %s", "No vision LLM provided, falling back to document parser for %s",

Some files were not shown because too many files have changed in this diff Show more