Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-06 06:12:40 +02:00)
Merge upstream/dev into feature/multi-agent
This commit is contained in commit 5119915f4f.
278 changed files with 34669 additions and 8970 deletions.
.github/workflows/desktop-release.yml (vendored, 5 changes)

@@ -144,6 +144,11 @@ jobs:
       APPLE_ID: ${{ secrets.APPLE_ID }}
       APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
       APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
+      # TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
+      # Surfaces the exact codesign / notarize commands electron-builder spawns,
+      # so we can see which subprocess hangs.
+      DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
+      ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
       # Service principal credentials for Azure.Identity EnvironmentCredential used by the
       # TrustedSigning PowerShell module. Only populated when signing is enabled.
       # electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
.github/workflows/notary-status.yml (vendored, new file, 60 lines added)

@@ -0,0 +1,60 @@
name: Notary status check

# One-off diagnostic workflow. Queries Apple's notary service to see if your
# submissions are queued, in progress, accepted, or rejected. Useful when a
# notarization seems "hung" — most often the queue itself, especially on a
# brand-new Apple Developer account.
#
# Run via: Actions tab -> "Notary status check" -> Run workflow.
# Inputs are optional; if you provide a submission ID, it also fetches that
# submission's full Apple log.
#
# Safe to delete after diagnosis.

on:
  workflow_dispatch:
    inputs:
      submission_id:
        description: 'Optional: submission UUID to fetch full Apple log for'
        required: false
        default: ''

jobs:
  status:
    runs-on: macos-latest
    steps:
      - name: List recent notarization submissions
        env:
          APPLE_ID: ${{ secrets.APPLE_ID }}
          APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
          APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
        run: |
          set -euo pipefail
          echo "::group::Submission history (most recent first)"
          xcrun notarytool history \
            --apple-id "$APPLE_ID" \
            --password "$APPLE_APP_SPECIFIC_PASSWORD" \
            --team-id "$APPLE_TEAM_ID"
          echo "::endgroup::"

      - name: Inspect specific submission (if id provided)
        if: ${{ inputs.submission_id != '' }}
        env:
          APPLE_ID: ${{ secrets.APPLE_ID }}
          APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
          APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
          SUBMISSION_ID: ${{ inputs.submission_id }}
        run: |
          set -euo pipefail
          echo "::group::Submission info"
          xcrun notarytool info "$SUBMISSION_ID" \
            --apple-id "$APPLE_ID" \
            --password "$APPLE_APP_SPECIFIC_PASSWORD" \
            --team-id "$APPLE_TEAM_ID"
          echo "::endgroup::"
          echo "::group::Apple's processing log for this submission"
          xcrun notarytool log "$SUBMISSION_ID" \
            --apple-id "$APPLE_ID" \
            --password "$APPLE_APP_SPECIFIC_PASSWORD" \
            --team-id "$APPLE_TEAM_ID" || true
          echo "::endgroup::"
.vscode/launch.json (vendored, 31 changes)

@@ -26,7 +26,16 @@
       "pythonArgs": [
         "run",
         "python"
-      ]
+      ],
+      // Mute LangGraph/Pydantic checkpoint serializer warnings
+      // (UserWarnings emitted from pydantic/main.py when the
+      // runtime snapshots a SurfSenseContextSchema into a field
+      // typed `None`) so the debugger's "Raised Exceptions"
+      // breakpoint doesn't pause on a known-harmless event.
+      // Production logs are unaffected.
+      "env": {
+        "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+      }
     },
     {
       "name": "Backend: FastAPI (No Reload)",
@@ -40,7 +49,10 @@
       "pythonArgs": [
         "run",
         "python"
-      ]
+      ],
+      "env": {
+        "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+      }
     },
     {
       "name": "Backend: FastAPI (main.py)",
@@ -54,7 +66,10 @@
       "pythonArgs": [
         "run",
         "python"
-      ]
+      ],
+      "env": {
+        "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+      }
     },
     {
       "name": "Frontend: Next.js",
@@ -104,7 +119,10 @@
       "pythonArgs": [
         "run",
         "python"
-      ]
+      ],
+      "env": {
+        "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+      }
     },
     {
       "name": "Celery: Beat Scheduler",
@@ -124,7 +142,10 @@
       "pythonArgs": [
         "run",
         "python"
-      ]
+      ],
+      "env": {
+        "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
+      }
     }
   ],
   "compounds": [
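The PYTHONWARNINGS value added to each launch configuration uses Python's standard warning-filter syntax (action::category:module). A minimal in-code sketch of the same filter, shown only to illustrate what these configs suppress, not how SurfSense itself applies it:

    import warnings

    # Same effect as PYTHONWARNINGS="ignore::UserWarning:pydantic.main":
    # ignore UserWarning instances raised from the pydantic.main module.
    warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")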
VERSION (2 changes)

@@ -1 +1 @@
-0.0.19
+0.0.21

@@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
 # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
 # STRIPE_RECONCILIATION_BATCH_SIZE=100
 
-# Premium token purchases ($1 per 1M tokens for premium-tier models)
+# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
+# credit; premium turns debit the actual per-call provider cost
+# reported by LiteLLM, so cheap and expensive models bill proportionally)
 # STRIPE_TOKEN_BUYING_ENABLED=FALSE
 # STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-# STRIPE_TOKENS_PER_UNIT=1000000
+# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — STRIPE_TOKENS_PER_UNIT=1000000
 
 # ------------------------------------------------------------------------------
 # TTS & STT (Text-to-Speech / Speech-to-Text)
@@ -305,6 +308,24 @@ STT_SERVICE=local/base
 # Advanced (optional)
 # ------------------------------------------------------------------------------
 
+# New-chat agent feature flags
+SURFSENSE_ENABLE_CONTEXT_EDITING=true
+SURFSENSE_ENABLE_COMPACTION_V2=true
+SURFSENSE_ENABLE_RETRY_AFTER=true
+SURFSENSE_ENABLE_MODEL_FALLBACK=false
+SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
+SURFSENSE_ENABLE_BUSY_MUTEX=true
+SURFSENSE_ENABLE_SKILLS=true
+SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
+SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
+SURFSENSE_ENABLE_ACTION_LOG=true
+SURFSENSE_ENABLE_REVERT_ROUTE=true
+SURFSENSE_ENABLE_PERMISSION=true
+SURFSENSE_ENABLE_DOOM_LOOP=true
+SURFSENSE_ENABLE_STREAM_PARITY_V2=true
+
 # Periodic connector sync interval (default: 5m)
 # SCHEDULE_CHECKER_INTERVAL=5m
 
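The SURFSENSE_ENABLE_* switches above are plain true/false environment variables. A minimal sketch of how such flags are typically parsed into Python booleans; the helper below is illustrative only and is not the actual AgentFeatureFlags/get_flags implementation referenced later in this diff:

    import os

    def _env_flag(name: str, default: bool = False) -> bool:
        """Read a SURFSENSE_ENABLE_* style flag; 'true' / '1' / 'yes' enable it."""
        raw = os.getenv(name, "true" if default else "false")
        return raw.strip().lower() in {"true", "1", "yes"}

    ENABLE_SKILLS = _env_flag("SURFSENSE_ENABLE_SKILLS", default=True)
    ENABLE_MODEL_FALLBACK = _env_flag("SURFSENSE_ENABLE_MODEL_FALLBACK", default=False)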
@@ -315,9 +336,24 @@ STT_SERVICE=local/base
 # Pages limit per user for ETL (default: unlimited)
 # PAGES_LIMIT=500
 
-# Premium token quota per registered user (default: 5M)
-# Only applies to models with billing_tier=premium in global_llm_config.yaml
-# PREMIUM_TOKEN_LIMIT=5000000
+# Premium credit quota per registered user, in micro-USD (default: $5).
+# Premium turns are debited at the actual per-call provider cost reported
+# by LiteLLM. Only applies to models with billing_tier=premium.
+# PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — PREMIUM_TOKEN_LIMIT=5000000
+
+# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
+# QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
+# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation for the podcast Celery task ($0.20 default).
+# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation for the video Celery task ($1.00 default).
+# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
 
 # No-login (anonymous) mode — public users can chat without an account
 # Set TRUE to enable /free pages and anonymous chat API
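Worked example of the unit change described in these env comments: credit is tracked in micro-USD, so the old "$1 per 1M tokens" Stripe pack maps 1:1 onto 1,000,000 credit micros, and a premium turn is debited at whatever LiteLLM reports the call actually cost. A minimal sketch; function names are illustrative, not the real billing code:

    MICROS_PER_DOLLAR = 1_000_000  # 1 credit micro == $0.000001

    def dollars_to_micros(dollars: float) -> int:
        return round(dollars * MICROS_PER_DOLLAR)

    def debit_turn(balance_micros: int, call_cost_usd: float) -> int:
        """Debit the actual per-call provider cost (as reported by LiteLLM)."""
        return balance_micros - dollars_to_micros(call_cost_usd)

    # A $1 pack grants 1_000_000 micros; a cheap call costing $0.0023 debits 2_300.
    balance = dollars_to_micros(1.00)      # 1_000_000
    balance = debit_turn(balance, 0.0023)  # 997_700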

@@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
 # Set FALSE to disable new checkout session creation temporarily
 STRIPE_PAGE_BUYING_ENABLED=TRUE
 
-# Premium token purchases via Stripe (for premium-tier model usage)
-# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
+# Premium credit purchases via Stripe (for premium-tier model usage).
+# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
+# (default 1_000_000 = $1.00). Premium turns are billed at the actual
+# per-call provider cost reported by LiteLLM.
 STRIPE_TOKEN_BUYING_ENABLED=FALSE
 STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
-STRIPE_TOKENS_PER_UNIT=1000000
+STRIPE_CREDIT_MICROS_PER_UNIT=1000000
+# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
+# STRIPE_TOKENS_PER_UNIT=1000000
 
 # Periodic Stripe safety net for purchases left in PENDING (minutes old)
 STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
@@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
 # (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
 PAGES_LIMIT=500
 
-# Premium token quota per registered user (default: 3,000,000)
-# Applies only to models with billing_tier=premium in global_llm_config.yaml
-PREMIUM_TOKEN_LIMIT=3000000
+# Premium credit quota per registered user, in micro-USD
+# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
+# actual per-call provider cost reported by LiteLLM, so cheap and expensive
+# models bill proportionally. Applies only to models with
+# billing_tier=premium in global_llm_config.yaml.
+PREMIUM_CREDIT_MICROS_LIMIT=5000000
+# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
+# PREMIUM_TOKEN_LIMIT=5000000
 
+# Safety ceiling on per-call premium reservation, in micro-USD.
+# stream_new_chat estimates an upper-bound cost from the model's
+# litellm-published per-token rates × the config's quota_reserve_tokens
+# and clamps to this value so a misconfigured model can't lock the
+# user's whole balance on one call. Default $1.00.
+QUOTA_MAX_RESERVE_MICROS=1000000
+
+# Per-image reservation (in micro-USD) for the POST /image-generations
+# endpoint. Bypassed for free configs. Default $0.05.
+QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
+
+# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
+# Single envelope covers one transcript-generation LLM call. Default $0.20.
+QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
+
+# Per-video-presentation reservation (in micro-USD) used by the video
+# presentation Celery task. Covers worst-case fan-out of N slide-scene
+# generations + refines. Default $1.00. NOTE: tasks using the override
+# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
+QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
+
 # No-login (anonymous) mode — allows public users to chat without an account
 # Set TRUE to enable /free pages and anonymous chat API
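A small sketch of the reserve-then-clamp rule that the QUOTA_MAX_RESERVE_MICROS comment describes: an upper-bound cost is estimated from per-token rates multiplied by the configured quota_reserve_tokens, then clamped so one call can never reserve more than the ceiling. Names and the worst-rate assumption are illustrative, not the actual stream_new_chat code:

    MICROS_PER_DOLLAR = 1_000_000

    def estimate_reserve_micros(
        input_rate_usd_per_token: float,
        output_rate_usd_per_token: float,
        quota_reserve_tokens: int,
        max_reserve_micros: int = 1_000_000,  # QUOTA_MAX_RESERVE_MICROS default ($1.00)
    ) -> int:
        # Upper bound: assume every reserved token bills at the pricier rate.
        worst_rate = max(input_rate_usd_per_token, output_rate_usd_per_token)
        estimate = round(worst_rate * quota_reserve_tokens * MICROS_PER_DOLLAR)
        # Clamp so a misconfigured model can't lock the user's whole balance.
        return min(estimate, max_reserve_micros)

    # e.g. $15 per 1M output tokens with 200_000 reserve tokens estimates $3.00,
    # which the clamp reduces to the $1.00 ceiling (1_000_000 micros).
    print(estimate_reserve_micros(3e-6, 15e-6, 200_000))  # 1000000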
@@ -297,3 +327,30 @@ LANGSMITH_PROJECT=surfsense
 # SURFSENSE_ENABLE_PLUGIN_LOADER=false
 # Comma-separated allowlist of plugin entry-point names
 # SURFSENSE_ALLOWED_PLUGINS=year_substituter
+
+# -----------------------------------------------------------------------------
+# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
+# -----------------------------------------------------------------------------
+# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
+# on a cold turn) is reused across subsequent turns on the same thread,
+# collapsing it to a microsecond hash lookup. All connector tools acquire
+# their own short-lived DB session per call (Phase 2 refactor) so a cached
+# closure is safe to share across requests. Flip OFF only as a last-resort
+# rollback if you suspect cache-related staleness.
+# SURFSENSE_ENABLE_AGENT_CACHE=true
+
+# Cache capacity (max number of compiled-agent entries kept in memory)
+# and TTL per entry (seconds). Working set is typically one entry per
+# active thread on this replica; tune up for very large deployments.
+# SURFSENSE_AGENT_CACHE_MAXSIZE=256
+# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
+
+# -----------------------------------------------------------------------------
+# Connector discovery TTL cache (Phase 1.4 perf optimization)
+# -----------------------------------------------------------------------------
+# Caches the per-search-space "available connectors" + "available document
+# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
+# ORM event listeners auto-invalidate on connector / document inserts,
+# updates and deletes — the TTL only bounds staleness for bulk-import
+# paths that bypass the ORM. Set to 0 to disable the cache.
+# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30
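The two cache knobs above (max entries plus per-entry TTL) map naturally onto a keyed TTL cache. A minimal sketch of that idea using cachetools, assuming the compile inputs can be reduced to a hashable key; this is not the actual SurfSense cache implementation:

    from cachetools import TTLCache

    # Analogue of SURFSENSE_AGENT_CACHE_MAXSIZE / SURFSENSE_AGENT_CACHE_TTL_SECONDS.
    compiled_agents = TTLCache(maxsize=256, ttl=1800)

    def get_compiled_agent(cache_key, build_fn):
        """Return a cached compiled agent, or build and memoize it.

        A warm lookup is a dict-style hash probe; only a cold or expired key
        pays the multi-second compile cost again.
        """
        try:
            return compiled_agents[cache_key]
        except KeyError:
            compiled_agents[cache_key] = build_fn()
            return compiled_agents[cache_key]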

@@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
 COPY pyproject.toml .
 COPY uv.lock .
 
-# Install PyTorch based on architecture
-RUN if [ "$(uname -m)" = "x86_64" ]; then \
-    pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
-    else \
-    pip install --no-cache-dir torch torchvision torchaudio; \
-    fi
-
-# Install python dependencies
+# Install all Python dependencies from uv.lock for deterministic builds.
+#
+# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
+# which lets prod silently drift to newer upstream versions on every rebuild
+# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
+# Exporting the lock to requirements.txt and feeding it to `uv pip install`
+# pins every transitive package to the exact version captured in uv.lock.
+#
+# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
+# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
+# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
+# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
+# downloads that the lock-based install immediately replaced. If a specific
+# CUDA version is needed (driver compatibility, etc.), wire it through
+# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
 RUN pip install --no-cache-dir uv && \
-    uv pip install --system --no-cache-dir -e .
+    uv export --frozen --no-dev --no-hashes --no-emit-project \
+        --format requirements-txt -o /tmp/requirements.txt && \
+    uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
+    rm /tmp/requirements.txt
 
 # Set SSL environment variables dynamically
 RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
@@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
 # Pre-download Docling models
 RUN python -c "try:\n from docling.document_converter import DocumentConverter\n conv = DocumentConverter()\nexcept:\n pass" || true
 
-# Install Playwright browsers for web scraping if needed
-RUN pip install playwright && \
-    playwright install chromium --with-deps
+# Install Playwright browsers for web scraping (the playwright package itself
+# is already installed via uv.lock above)
+RUN playwright install chromium --with-deps
 
 # Copy source code
 COPY . .
+
+# Install the project itself in editable mode. Dependencies were already
+# installed deterministically from uv.lock above, so --no-deps prevents any
+# re-resolution that could pull newer versions.
+RUN uv pip install --system --no-cache-dir --no-deps -e .
 
 # Copy and set permissions for entrypoint script
 # Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
 COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
@@ -0,0 +1,44 @@
"""138_add_thread_auto_model_pinning_fields

Revision ID: 138
Revises: 137
Create Date: 2026-04-30

Add a single thread-level column to persist the Auto (Fastest) model pin:
- pinned_llm_config_id: concrete resolved global LLM config id used for this
  thread. NULL means "no pin; Auto will resolve on next turn".

The column is unindexed: all reads are by new_chat_threads.id (primary key),
so a secondary index would be dead write amplification.
"""

from __future__ import annotations

from collections.abc import Sequence

from alembic import op

revision: str = "138"
down_revision: str | None = "137"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    op.execute(
        "ALTER TABLE new_chat_threads "
        "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
    )


def downgrade() -> None:
    # Drop any shape the thread row may be carrying. The extra columns and
    # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
    # makes each statement a safe no-op on the lean shape.
    op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
    op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
    op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
    op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
    op.execute(
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id"
    )
@@ -0,0 +1,160 @@
"""add user table to zero_publication with column list

Adds the "user" table to zero_publication with a column-list publication
so that only the 5 fields driving the live usage meters are replicated
through WAL -> zero-cache -> browser IndexedDB:

    id, pages_limit, pages_used,
    premium_tokens_limit, premium_tokens_used

Sensitive columns (hashed_password, email, oauth_account, display_name,
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
included in the publication, so they never enter WAL replication.

Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
(it is already DEFAULT today since "user" was never in the
TABLES_WITH_FULL_IDENTITY list of migration 117).

IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)

Revision ID: 139
Revises: 138
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "139"
down_revision: str | None = "138"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

PUBLICATION_NAME = "zero_publication"

# Document column list as left by migration 117. Must match exactly.
DOCUMENT_COLS = [
    "id",
    "title",
    "document_type",
    "search_space_id",
    "folder_id",
    "created_by_id",
    "status",
    "created_at",
    "updated_at",
]

# Five fields needed by the live usage meters (sidebar Tokens/Pages,
# Buy Tokens content). Keep this list narrow on purpose: anything added
# here flows into WAL and IndexedDB for every connected browser.
USER_COLS = [
    "id",
    "pages_limit",
    "pages_used",
    "premium_tokens_limit",
    "premium_tokens_used",
]


def _terminate_blocked_pids(conn, table: str) -> None:
    """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            " AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )


def _has_zero_version(conn, table: str) -> bool:
    return (
        conn.execute(
            sa.text(
                "SELECT 1 FROM information_schema.columns "
                "WHERE table_name = :tbl AND column_name = '_0_version'"
            ),
            {"tbl": table},
        ).fetchone()
        is not None
    )


def _build_publication_ddl(
    documents_has_zero_ver: bool, user_has_zero_ver: bool
) -> str:
    doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
    user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
    doc_col_list = ", ".join(doc_cols)
    user_col_list = ", ".join(user_cols)
    return (
        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
        f"notifications, "
        f"documents ({doc_col_list}), "
        f"folders, "
        f"search_source_connectors, "
        f"new_chat_messages, "
        f"chat_comments, "
        f"chat_session_state, "
        f'"user" ({user_col_list})'
    )


def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
    doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
    doc_col_list = ", ".join(doc_cols)
    return (
        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
        f"notifications, "
        f"documents ({doc_col_list}), "
        f"folders, "
        f"search_source_connectors, "
        f"new_chat_messages, "
        f"chat_comments, "
        f"chat_session_state"
    )


def upgrade() -> None:
    conn = op.get_bind()
    # asyncpg requires LOCK TABLE inside a transaction block. Alembic already
    # opened one via context.begin_transaction(), but the driver still errors
    # unless we use an explicit SAVEPOINT (nested transaction) for this block.
    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        conn.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
        # migration 117, so this is already DEFAULT. Re-assert anyway so
        # the column-list publication stays valid (DEFAULT identity only
        # requires the PK to be in the column list).
        conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")

        conn.execute(
            sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
        )


def downgrade() -> None:
    conn = op.get_bind()
    conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
    documents_has_zero_ver = _has_zero_version(conn, "documents")
    conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))
@@ -0,0 +1,291 @@
"""rename premium token columns to credit micros and add cost_micros to token_usage

Migrates the premium quota system from a flat token counter to a USD-cost
based credit system, where 1 credit = 1 micro-USD ($0.000001).

Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
price means every existing value is already correct in the new unit, no
data transformation needed):

    user.premium_tokens_limit -> premium_credit_micros_limit
    user.premium_tokens_used -> premium_credit_micros_used
    user.premium_tokens_reserved -> premium_credit_micros_reserved

    premium_token_purchases.tokens_granted -> credit_micros_granted

New column for cost auditing per turn:

    token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)

The "user" table is in zero_publication's column list (added in 139), so
this migration must drop and recreate the publication with the renamed
column names, otherwise zero-cache will replicate stale column names and
the FE Zero schema will fail to bind.

IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)

Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
column-not-found errors from a stale catalog snapshot.

Revision ID: 140
Revises: 139
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "140"
down_revision: str | None = "139"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

PUBLICATION_NAME = "zero_publication"

# Replicates 139's document column list verbatim — must stay in sync.
DOCUMENT_COLS = [
    "id",
    "title",
    "document_type",
    "search_space_id",
    "folder_id",
    "created_by_id",
    "status",
    "created_at",
    "updated_at",
]

# Same five live-meter fields as 139, with the renamed column names.
USER_COLS = [
    "id",
    "pages_limit",
    "pages_used",
    "premium_credit_micros_limit",
    "premium_credit_micros_used",
]


def _terminate_blocked_pids(conn, table: str) -> None:
    """Kill backends whose locks on *table* would block our AccessExclusiveLock."""
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            " AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )


def _has_zero_version(conn, table: str) -> bool:
    return (
        conn.execute(
            sa.text(
                "SELECT 1 FROM information_schema.columns "
                "WHERE table_name = :tbl AND column_name = '_0_version'"
            ),
            {"tbl": table},
        ).fetchone()
        is not None
    )


def _column_exists(conn, table: str, column: str) -> bool:
    return (
        conn.execute(
            sa.text(
                "SELECT 1 FROM information_schema.columns "
                "WHERE table_name = :tbl AND column_name = :col"
            ),
            {"tbl": table, "col": column},
        ).fetchone()
        is not None
    )


def _build_publication_ddl(
    user_cols: list[str],
    *,
    documents_has_zero_ver: bool,
    user_has_zero_ver: bool,
) -> str:
    doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
    user_col_list_with_meta = user_cols + (
        ['"_0_version"'] if user_has_zero_ver else []
    )
    doc_col_list = ", ".join(doc_cols)
    user_col_list = ", ".join(user_col_list_with_meta)
    return (
        f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
        f"notifications, "
        f"documents ({doc_col_list}), "
        f"folders, "
        f"search_source_connectors, "
        f"new_chat_messages, "
        f"chat_comments, "
        f"chat_session_state, "
        f'"user" ({user_col_list})'
    )


def upgrade() -> None:
    conn = op.get_bind()

    # ------------------------------------------------------------------
    # 1. Add cost_micros to token_usage. Idempotent guard so re-runs in
    #    dev environments are safe.
    # ------------------------------------------------------------------
    if not _column_exists(conn, "token_usage", "cost_micros"):
        op.add_column(
            "token_usage",
            sa.Column(
                "cost_micros",
                sa.BigInteger(),
                nullable=False,
                server_default="0",
            ),
        )

    # ------------------------------------------------------------------
    # 2. Rename premium_token_purchases.tokens_granted -> credit_micros_granted.
    # ------------------------------------------------------------------
    if _column_exists(
        conn, "premium_token_purchases", "tokens_granted"
    ) and not _column_exists(conn, "premium_token_purchases", "credit_micros_granted"):
        op.alter_column(
            "premium_token_purchases",
            "tokens_granted",
            new_column_name="credit_micros_granted",
        )

    # ------------------------------------------------------------------
    # 3. Rename user.premium_tokens_* -> premium_credit_micros_*.
    #
    #    We must drop the publication first (it references the old column
    #    names) and re-acquire the lock for DDL. asyncpg requires LOCK TABLE
    #    in a transaction block; alembic's outer transaction already holds
    #    one, but a SAVEPOINT keeps the LOCK + DDL atomic.
    # ------------------------------------------------------------------
    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        conn.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Re-assert REPLICA IDENTITY DEFAULT for safety; column-list
        # publications require at least the PK to be in the column list,
        # which is true for both the old and new shape.
        conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        # Drop the publication BEFORE renaming columns, otherwise Postgres
        # rejects the rename: "cannot drop column ... referenced by
        # publication".
        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for old, new in (
            ("premium_tokens_limit", "premium_credit_micros_limit"),
            ("premium_tokens_used", "premium_credit_micros_used"),
            ("premium_tokens_reserved", "premium_credit_micros_reserved"),
        ):
            if _column_exists(conn, "user", old) and not _column_exists(
                conn, "user", new
            ):
                op.alter_column("user", old, new_column_name=new)

        # Update the server_default on the renamed limit column so newly
        # inserted users get $5 of credit (== 5_000_000 micros) by
        # default. Existing rows are unaffected.
        op.alter_column(
            "user",
            "premium_credit_micros_limit",
            server_default="5000000",
        )

        # Recreate the publication with the new column names.
        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")
        conn.execute(
            sa.text(
                _build_publication_ddl(
                    USER_COLS,
                    documents_has_zero_ver=documents_has_zero_ver,
                    user_has_zero_ver=user_has_zero_ver,
                )
            )
        )


def downgrade() -> None:
    """Revert the rename and drop ``cost_micros``.

    Mirrors ``upgrade``: drop the publication, rename columns back, drop
    the new column, recreate the publication with the old column list.
    Same zero-cache stop/reset runbook applies in reverse.
    """
    conn = op.get_bind()

    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        conn.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for new, old in (
            ("premium_credit_micros_limit", "premium_tokens_limit"),
            ("premium_credit_micros_used", "premium_tokens_used"),
            ("premium_credit_micros_reserved", "premium_tokens_reserved"),
        ):
            if _column_exists(conn, "user", new) and not _column_exists(
                conn, "user", old
            ):
                op.alter_column("user", new, new_column_name=old)

        op.alter_column(
            "user",
            "premium_tokens_limit",
            server_default="5000000",
        )

        legacy_user_cols = [
            "id",
            "pages_limit",
            "pages_used",
            "premium_tokens_limit",
            "premium_tokens_used",
        ]
        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")
        conn.execute(
            sa.text(
                _build_publication_ddl(
                    legacy_user_cols,
                    documents_has_zero_ver=documents_has_zero_ver,
                    user_has_zero_ver=user_has_zero_ver,
                )
            )
        )

    if _column_exists(
        conn, "premium_token_purchases", "credit_micros_granted"
    ) and not _column_exists(conn, "premium_token_purchases", "tokens_granted"):
        op.alter_column(
            "premium_token_purchases",
            "credit_micros_granted",
            new_column_name="tokens_granted",
        )

    if _column_exists(conn, "token_usage", "cost_micros"):
        op.drop_column("token_usage", "cost_micros")
@@ -0,0 +1,66 @@
"""141_unique_chat_message_turn_role

Revision ID: 141
Revises: 140
Create Date: 2026-05-04

Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)``
where ``turn_id IS NOT NULL``.

Why
---
The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being
moved to write its own ``new_chat_messages`` rows server-side instead of
relying on the frontend's later ``POST /threads/{id}/messages`` call. This
closes the "ghost-thread" abuse vector where authenticated callers got free
LLM completions while ``new_chat_messages`` stayed empty.

For server-side and legacy frontend writes to coexist we need an idempotency
key. The natural triple is ``(thread_id, turn_id, role)``: the server issues
exactly one ``turn_id`` per turn, and a turn produces at most one user
message and one assistant message. Whichever side wins the race writes the
row; the loser hits ``IntegrityError`` and recovers gracefully.

Partial — ``WHERE turn_id IS NOT NULL`` — so:

* Legacy rows that predate the ``turn_id`` column (migration 136) keep
  co-existing without de-dup.
* Clone / snapshot inserts in
  ``app/services/public_chat_service.py`` that build ``NewChatMessage``
  without ``turn_id`` are unaffected (multiple snapshot copies of the same
  user/assistant pair are intentional).

This index coexists with the existing single-column ``ix_new_chat_messages_turn_id``
from migration 136 — no collision.
"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "141"
down_revision: str | None = "140"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


INDEX_NAME = "uq_new_chat_messages_thread_turn_role"
TABLE_NAME = "new_chat_messages"


def upgrade() -> None:
    op.create_index(
        INDEX_NAME,
        TABLE_NAME,
        ["thread_id", "turn_id", "role"],
        unique=True,
        postgresql_where=sa.text("turn_id IS NOT NULL"),
    )


def downgrade() -> None:
    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
@@ -0,0 +1,134 @@
"""142_token_usage_message_id_unique

Revision ID: 142
Revises: 141
Create Date: 2026-05-04

Add a partial unique index on ``token_usage(message_id)`` where
``message_id IS NOT NULL``.

Why
---
Two writers can race on the same assistant turn's ``token_usage`` row:

* ``finalize_assistant_turn`` (server-side, called from the streaming
  finally block in ``stream_new_chat`` / ``stream_resume_chat``)
* ``append_message``'s recovery branch in
  ``app/routes/new_chat_routes.py`` (legacy frontend round-trip)

Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a
micro-second-aligned race could observe "no row" on each side and double
INSERT, producing duplicate ``token_usage`` rows for the same
``message_id``.

A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``)
turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING``
no-ops for the loser, hard-eliminating the race at the DB level. Partial
because non-chat usage rows (indexing, image generation, podcasts) keep
``message_id`` NULL — they're per-event, no de-dup needed.

Pre-flight
----------
Today's schema only has a non-unique index on ``message_id`` so a
duplicate population could already exist from any past race. We:

* Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``).
* If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50)
  we dedupe by deleting all but the smallest ``id`` per group.
* If the count exceeds the threshold we abort with a descriptive
  error rather than silently mutate prod data — operator must
  investigate before retrying.

Concurrency
-----------
``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid
stalling production writes during deploy (a regular ``CREATE INDEX``
holds an ACCESS EXCLUSIVE lock for the duration of the build, which
would block ``token_usage`` INSERTs for every active streaming chat).
The trade-off is a slower migration (CONCURRENTLY scans the table
twice) and the ``CREATE`` statement cannot run inside alembic's default
transaction wrapper — ``autocommit_block()`` handles that.
"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "142"
down_revision: str | None = "141"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


INDEX_NAME = "uq_token_usage_message_id"
TABLE_NAME = "token_usage"

# Refuse to silently mutate prod data if the duplicate population is
# unexpectedly large — operator should investigate the upstream cause
# before retrying. 50 is comfortably above any plausible duplicate
# count from the existing race window (the race is microseconds wide).
DUPLICATE_ABORT_THRESHOLD = 50


def upgrade() -> None:
    conn = op.get_bind()

    dup_groups = conn.execute(
        sa.text(
            "SELECT message_id, COUNT(*) AS n "
            "FROM token_usage "
            "WHERE message_id IS NOT NULL "
            "GROUP BY message_id "
            "HAVING COUNT(*) > 1"
        )
    ).fetchall()

    if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD:
        raise RuntimeError(
            f"token_usage has {len(dup_groups)} duplicate message_id groups "
            f"(threshold={DUPLICATE_ABORT_THRESHOLD}). "
            "Resolve the duplicates manually before re-running this migration."
        )

    if dup_groups:
        # Delete all but the smallest-id row per duplicate group. The
        # smallest id is by definition the earliest insert, so we keep
        # the row most likely to reflect the actual stream's first
        # successful write.
        conn.execute(
            sa.text(
                """
                DELETE FROM token_usage
                WHERE id IN (
                    SELECT id FROM (
                        SELECT
                            id,
                            row_number() OVER (
                                PARTITION BY message_id ORDER BY id ASC
                            ) AS rn
                        FROM token_usage
                        WHERE message_id IS NOT NULL
                    ) ranked
                    WHERE rn > 1
                )
                """
            )
        )

    # CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop
    # alembic's auto-transaction for this op only.
    with op.get_context().autocommit_block():
        op.execute(
            f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} "
            f"ON {TABLE_NAME} (message_id) "
            "WHERE message_id IS NOT NULL"
        )


def downgrade() -> None:
    with op.get_context().autocommit_block():
        op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}")
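For context on how the new unique index gets used: with it in place, each racing writer can issue a PostgreSQL INSERT ... ON CONFLICT (message_id) DO NOTHING so the loser simply no-ops instead of inserting a duplicate. A minimal SQLAlchemy sketch under assumed table and column names; this is not the actual finalize_assistant_turn / append_message code:

    from sqlalchemy.dialects.postgresql import insert

    def record_token_usage(session, token_usage_table, message_id, cost_micros):
        stmt = (
            insert(token_usage_table)
            .values(message_id=message_id, cost_micros=cost_micros)
            .on_conflict_do_nothing(index_elements=["message_id"])
        )
        # A duplicate message_id from the racing writer becomes a harmless no-op.
        session.execute(stmt)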
@@ -11,7 +11,6 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.tools import BaseTool
 from langgraph.types import Checkpointer
 
-from .middleware import build_main_agent_deepagent_middleware
 from app.agents.multi_agent_chat.subagents.shared.permissions import (
     ToolsPermissions,
 )
@@ -20,6 +19,8 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
 from app.agents.new_chat.filesystem_selection import FilesystemMode
 from app.db import ChatVisibility
 
+from .middleware import build_main_agent_deepagent_middleware
+
 
 def build_compiled_agent_graph_sync(
     *,

@@ -31,8 +31,8 @@ from .propagation import (
 from .resume import (
     build_resume_command,
     fan_out_decisions_to_match,
-    hitlrequest_action_count,
     get_first_pending_subagent_interrupt,
+    hitlrequest_action_count,
 )
 
 logger = logging.getLogger(__name__)
@@ -51,7 +51,9 @@ def build_task_tool_with_parent_config(
     )
 
     if task_description is None:
-        description = TASK_TOOL_DESCRIPTION.format(available_agents=subagent_description_str)
+        description = TASK_TOOL_DESCRIPTION.format(
+            available_agents=subagent_description_str
+        )
     elif "{available_agents}" in task_description:
         description = task_description.format(available_agents=subagent_description_str)
     else:
@@ -90,11 +92,11 @@ def build_task_tool_with_parent_config(
     def task(
         description: Annotated[
             str,
-            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
+            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
         ],
         subagent_type: Annotated[
             str,
-            "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
+            "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
         ],
         runtime: ToolRuntime,
     ) -> str | Command:
@@ -119,7 +121,9 @@ def build_task_tool_with_parent_config(
         if callable(get_state):
             try:
                 snapshot = get_state(sub_config)
-                pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
+                pending_id, pending_value = get_first_pending_subagent_interrupt(
+                    snapshot
+                )
             except Exception:
                 # Fail loud if a resume is queued: silent fallback would
                 # replay the original interrupt to the user.
@@ -158,11 +162,11 @@ def build_task_tool_with_parent_config(
     async def atask(
         description: Annotated[
             str,
-            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
+            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
         ],
         subagent_type: Annotated[
             str,
-            "The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
+            "The type of subagent to use. Must be one of the available agent types listed in the tool description.",
         ],
         runtime: ToolRuntime,
     ) -> str | Command:
@@ -186,7 +190,9 @@ def build_task_tool_with_parent_config(
         if callable(aget_state):
             try:
                 snapshot = await aget_state(sub_config)
-                pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
+                pending_id, pending_value = get_first_pending_subagent_interrupt(
+                    snapshot
+                )
             except Exception:
                 if has_surfsense_resume(runtime):
                     logger.exception(
@@ -23,7 +23,6 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.tools import BaseTool
 from langgraph.types import Checkpointer
 
-from ...context_prune.prune_tool_names import safe_exclude_tools
 from app.agents.multi_agent_chat.subagents import (
     build_subagents,
     get_subagents_to_exclude,
@@ -66,6 +65,7 @@ from app.agents.new_chat.plugin_loader import (
 from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
 from app.db import ChatVisibility
 
+from ...context_prune.prune_tool_names import safe_exclude_tools
 from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
 
 

@@ -14,8 +14,10 @@ from langchain_core.tools import BaseTool
 from langgraph.types import Checkpointer
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
-from ..tools import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
+from app.agents.multi_agent_chat.subagents import (
+    get_subagents_to_exclude,
+    main_prompt_registry_subagent_lines,
+)
 from app.agents.multi_agent_chat.subagents.mcp_tools.index import (
     load_mcp_tools_by_connector,
 )
@@ -24,17 +26,19 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
 from app.agents.new_chat.filesystem_backends import build_backend_resolver
 from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
 from app.agents.new_chat.llm_config import AgentConfig
-from app.agents.multi_agent_chat.subagents import (
-    get_subagents_to_exclude,
-    main_prompt_registry_subagent_lines,
-)
-from ..system_prompt import build_main_agent_system_prompt
 from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
 from app.agents.new_chat.tools.registry import build_tools_async
 from app.db import ChatVisibility
 from app.services.connector_service import ConnectorService
 from app.utils.perf import get_perf_logger
 
+from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
+from ..system_prompt import build_main_agent_system_prompt
+from ..tools import (
+    MAIN_AGENT_SURFSENSE_TOOL_NAMES,
+    MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
+)
+
 _perf_log = get_perf_logger()
 
 
|
|
@ -2,6 +2,9 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .index import MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
|
from .index import (
|
||||||
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||||
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"]
|
__all__ = ["MAIN_AGENT_SURFSENSE_TOOL_NAMES", "MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED"]
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,9 @@ from .resume import create_generate_resume_tool
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
from .video_presentation import create_generate_video_presentation_tool
|
||||||
|
|
||||||
|
|
||||||
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
|
def load_tools(
|
||||||
|
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
|
||||||
|
) -> ToolsPermissions:
|
||||||
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
||||||
podcast = create_generate_podcast_tool(
|
podcast = create_generate_podcast_tool(
|
||||||
search_space_id=resolved_dependencies["search_space_id"],
|
search_space_id=resolved_dependencies["search_space_id"],
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,9 @@ from app.db import ChatVisibility
|
||||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||||
|
|
||||||
|
|
||||||
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
|
def load_tools(
|
||||||
|
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
|
||||||
|
) -> ToolsPermissions:
|
||||||
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
||||||
if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE:
|
if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE:
|
||||||
mem = create_update_team_memory_tool(
|
mem = create_update_team_memory_tool(
|
||||||
|
|
@ -18,7 +20,10 @@ def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) ->
|
||||||
db_session=resolved_dependencies["db_session"],
|
db_session=resolved_dependencies["db_session"],
|
||||||
llm=resolved_dependencies.get("llm"),
|
llm=resolved_dependencies.get("llm"),
|
||||||
)
|
)
|
||||||
return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []}
|
return {
|
||||||
|
"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}],
|
||||||
|
"ask": [],
|
||||||
|
}
|
||||||
mem = create_update_memory_tool(
|
mem = create_update_memory_tool(
|
||||||
user_id=resolved_dependencies["user_id"],
|
user_id=resolved_dependencies["user_id"],
|
||||||
db_session=resolved_dependencies["db_session"],
|
db_session=resolved_dependencies["db_session"],
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,20 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
from .web_search import create_web_search_tool
|
from .web_search import create_web_search_tool
|
||||||
|
|
||||||
|
|
||||||
def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
|
def load_tools(
|
||||||
|
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
|
||||||
|
) -> ToolsPermissions:
|
||||||
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
resolved_dependencies = {**(dependencies or {}), **kwargs}
|
||||||
web = create_web_search_tool(
|
web = create_web_search_tool(
|
||||||
search_space_id=resolved_dependencies.get("search_space_id"),
|
search_space_id=resolved_dependencies.get("search_space_id"),
|
||||||
available_connectors=resolved_dependencies.get("available_connectors"),
|
available_connectors=resolved_dependencies.get("available_connectors"),
|
||||||
)
|
)
|
||||||
scrape = create_scrape_webpage_tool(firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key"))
|
scrape = create_scrape_webpage_tool(
|
||||||
docs = create_search_surfsense_docs_tool(db_session=resolved_dependencies["db_session"])
|
firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key")
|
||||||
|
)
|
||||||
|
docs = create_search_surfsense_docs_tool(
|
||||||
|
db_session=resolved_dependencies["db_session"]
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"allow": [
|
"allow": [
|
||||||
{"name": getattr(web, "name", "") or "", "tool": web},
|
{"name": getattr(web, "name", "") or "", "tool": web},
|
||||||
|
|
|
||||||
|
|
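Aside, not part of the diff: every `load_tools()` reformatted in the hunks above and below returns the same two-bucket permission structure, an "allow" list of tools exposed without confirmation and an "ask" list gated behind human-in-the-loop approval. The sketch below is a minimal illustration of that shape; the `ToolEntry`/`ToolsPermissions` TypedDicts and the `web_search_tool` dependency key are assumptions, not the real SurfSense definitions.

```python
# Minimal sketch of the allow/ask permission buckets the loaders return.
# ToolEntry, ToolsPermissions and the "web_search_tool" key are illustrative only.
from typing import Any, TypedDict


class ToolEntry(TypedDict):
    name: str
    tool: Any


class ToolsPermissions(TypedDict):
    allow: list[ToolEntry]
    ask: list[ToolEntry]


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved = {**(dependencies or {}), **kwargs}
    web = resolved.get("web_search_tool")  # hypothetical pre-built tool instance
    return {
        "allow": [{"name": getattr(web, "name", "") or "", "tool": web}],
        "ask": [],  # tools requiring confirmation would be listed here
    }
```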
@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}

@@ -12,7 +12,9 @@ from .search_events import create_search_calendar_events_tool
from .update_event import create_update_calendar_event_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    session_dependencies = {
        "db_session": resolved_dependencies["db_session"],

@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}

@@ -11,7 +11,9 @@ from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    session_dependencies = {
        "db_session": resolved_dependencies["db_session"],

@@ -11,7 +11,9 @@ from .read_messages import create_read_discord_messages_tool
from .send_message import create_send_discord_message_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -10,7 +10,9 @@ from .create_file import create_create_dropbox_file_tool
from .trash_file import create_delete_dropbox_file_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -14,7 +14,9 @@ from .trash_email import create_trash_gmail_email_tool
from .update_draft import create_update_gmail_draft_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -10,7 +10,9 @@ from .create_file import create_create_google_drive_file_tool
from .trash_file import create_delete_google_drive_file_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -11,7 +11,9 @@ from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -11,7 +11,9 @@ from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -11,7 +11,9 @@ from .list_events import create_list_luma_events_tool
from .read_event import create_read_luma_event_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -11,7 +11,9 @@ from .delete_page import create_delete_notion_page_tool
from .update_page import create_update_notion_page_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -10,7 +10,9 @@ from .create_file import create_create_onedrive_file_tool
from .trash_file import create_delete_onedrive_file_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -7,6 +7,8 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
)


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}

@@ -11,7 +11,9 @@ from .read_messages import create_read_teams_messages_tool
from .send_message import create_send_teams_message_tool


-def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
+def load_tools(
+    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
+) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],

@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)

## Helper functions for fetching connector metadata maps

+
async def fetch_mcp_connector_metadata_maps(
    session: AsyncSession,
    search_space_id: int,

@@ -58,6 +59,7 @@ async def fetch_mcp_connector_metadata_maps(

## Helper functions for partitioning tools by connector agent

+
def partition_mcp_tools_by_connector(
    tools: Sequence[BaseTool],
    connector_id_to_type: dict[int, str],

@@ -104,8 +106,10 @@ def partition_mcp_tools_by_connector(

    return dict(buckets)

+
## Helper functions for splitting tools by permissions

+
def _get_mcp_tool_name(tool: BaseTool) -> str:
    meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
    orig = meta.get("mcp_original_tool_name")

@@ -139,6 +143,7 @@ def _split_tools_by_permissions(

## Main function to load MCP tools and split them by permissions for each connector agent

+
async def load_mcp_tools_by_connector(
    session: AsyncSession,
    search_space_id: int,

@@ -148,9 +153,7 @@ async def load_mcp_tools_by_connector(
    Pass ``bypass_internal_hitl=True`` so the subagent's
    ``HumanInTheLoopMiddleware`` is the single HITL gate.
    """
-    flat = await load_mcp_tools(
-        session, search_space_id, bypass_internal_hitl=True
-    )
+    flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True)
    id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id)
    buckets = partition_mcp_tools_by_connector(flat, id_map, name_map)
    return {
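Aside, not part of the diff: the `load_mcp_tools_by_connector` path above loads a flat MCP tool list and then partitions it into per-connector buckets before splitting each bucket by permissions. The sketch below only illustrates the partitioning idea; the `mcp_connector_id` metadata key is an assumption (the diff only confirms that MCP tool metadata is a dict, via `mcp_original_tool_name`).

```python
# Illustrative sketch of partitioning tools into per-connector buckets.
# Not the real partition_mcp_tools_by_connector; the metadata key is assumed.
from collections import defaultdict
from typing import Any


def partition_by_connector(
    tools: list[Any],
    connector_id_to_type: dict[int, str],
) -> dict[str, list[Any]]:
    buckets: dict[str, list[Any]] = defaultdict(list)
    for tool in tools:
        meta = getattr(tool, "metadata", None) or {}
        connector_id = meta.get("mcp_connector_id")  # assumed metadata key
        connector_type = connector_id_to_type.get(connector_id, "unknown")
        buckets[connector_type].append(tool)
    return dict(buckets)
```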
@@ -8,6 +8,9 @@ from typing import Any, Protocol
from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

+from app.agents.multi_agent_chat.constants import (
+    SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
+)
from app.agents.multi_agent_chat.subagents.builtins.deliverables.agent import (
    build_subagent as build_deliverables_subagent,
)

@@ -62,9 +65,6 @@ from app.agents.multi_agent_chat.subagents.connectors.slack.agent import (
from app.agents.multi_agent_chat.subagents.connectors.teams.agent import (
    build_subagent as build_teams_subagent,
)
-from app.agents.multi_agent_chat.constants import (
-    SUBAGENT_TO_REQUIRED_CONNECTOR_MAP,
-)
from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)

@@ -105,6 +105,7 @@ SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = {
    "teams": build_teams_subagent,
}

+
def _route_resource_package(builder: SubagentBuilder) -> str:
    mod = builder.__module__
    return mod[: -len(".agent")] if mod.endswith(".agent") else mod.rsplit(".", 1)[0]

357  surfsense_backend/app/agents/new_chat/agent_cache.py  Normal file
@@ -0,0 +1,357 @@
"""TTL-LRU cache for compiled SurfSense deep agents.

Why this exists
---------------

``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
turn:

1. Discover connectors & document types from Postgres (~50-200ms)
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
3. Compose the system prompt
4. Construct ~15 middleware instances (CPU)
5. Eagerly compile the general-purpose subagent
   (``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
   which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure
   CPU work)
6. Compile the outer LangGraph

For a single thread, all six steps produce the SAME object on every turn
unless the user has changed their LLM config, toggled a feature flag,
added a connector, etc. The right answer is to compile ONCE per
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
every subsequent turn on the same thread.

Why a per-thread key (not a global pool)
----------------------------------------

Most middleware in the SurfSense stack captures per-thread state in
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
would silently leak state across users and threads. Keying the cache on
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
turns on the same thread without changing any middleware's behavior.

Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
(read via ``runtime.context``) so the cache can collapse to a single
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
then, per-thread keying is the only safe option.

Cache shape
-----------

* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
  minutes — matches a typical chat session). ``maxsize`` (default 256)
  caps memory; LRU evicts least-recently-used on overflow.
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
  cold misses on the same key wait for the first build instead of
  building N times.
* Process-local: this is an in-memory cache. Multi-replica deployments
  pay the build cost once per replica per key. That's fine; the working
  set per replica is small (one entry per active thread on that replica).

Telemetry
---------

Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:

* ``hit`` — cache hit, microseconds-fast
* ``miss`` — first build for this key, includes build duration
* ``stale`` — entry was found but expired; rebuilt
* ``evict`` — LRU eviction (size-limited)
* ``size`` — current cache occupancy at lookup time
"""

from __future__ import annotations

import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any

from app.utils.perf import get_perf_logger

logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()


# ---------------------------------------------------------------------------
# Public API: signature helpers (cache key components)
# ---------------------------------------------------------------------------


def stable_hash(*parts: Any) -> str:
    """Compute a deterministic SHA1 of the str repr of ``parts``.

    Used for cache key components that need a fixed-width representation
    (system prompt, tool list, etc.). SHA1 is fine here — this is not a
    security boundary, just a content fingerprint.
    """
    h = hashlib.sha1(usedforsecurity=False)
    for p in parts:
        h.update(repr(p).encode("utf-8", errors="replace"))
        h.update(b"\x1f")  # ASCII unit separator between parts
    return h.hexdigest()


def tools_signature(
    tools: list[Any] | tuple[Any, ...],
    *,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
) -> str:
    """Hash the bound-tool surface for cache-key purposes.

    The signature changes whenever:

    * A tool is added or removed from the bound list (built-in toggles,
      MCP tools loaded for the user changes, gating rules flip, etc.).
    * The available connectors / document types for the search space
      change (new connector added, last connector removed, new document
      type indexed). Because :func:`get_connector_gated_tools` derives
      ``modified_disabled_tools`` from ``available_connectors``, the
      tool surface is technically already covered — but we hash the
      connector list separately so an empty-list "no tools changed"
      situation still rotates the key when, say, the user re-adds a
      connector that gates a tool we were already not exposing.

    Stays stable across:

    * Process restarts (tool names + descriptions are static).
    * Different replicas (everyone gets the same hash for the same
      inputs).
    """
    tool_descriptors = sorted(
        (getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
    )
    connectors = sorted(available_connectors or [])
    doc_types = sorted(available_document_types or [])
    return stable_hash(tool_descriptors, connectors, doc_types)


def flags_signature(flags: Any) -> str:
    """Hash the resolved :class:`AgentFeatureFlags` dataclass.

    Frozen dataclasses are deterministically reprable, so a SHA1 of their
    repr is a stable fingerprint. Restart safe (flags are read once at
    process boot).
    """
    return stable_hash(repr(flags))


def system_prompt_hash(system_prompt: str) -> str:
    """Hash a system prompt string. Cheap, ~30µs for typical prompts."""
    return hashlib.sha1(
        system_prompt.encode("utf-8", errors="replace"),
        usedforsecurity=False,
    ).hexdigest()


# ---------------------------------------------------------------------------
# Cache implementation
# ---------------------------------------------------------------------------


@dataclass
class _Entry:
    value: Any
    created_at: float
    last_used_at: float


class _AgentCache:
    """In-process TTL-LRU cache with per-key in-flight de-duplication.

    NOT THREAD-SAFE in the multithreading sense — designed for a single
    asyncio event loop. Uvicorn runs one event loop per worker process,
    so this is fine; multi-worker deployments simply each maintain their
    own cache.
    """

    def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
        self._maxsize = maxsize
        self._ttl = ttl_seconds
        self._entries: OrderedDict[str, _Entry] = OrderedDict()
        # One lock per key — guards "build" so concurrent cold misses on
        # the same key wait for the first build instead of all racing.
        self._locks: dict[str, asyncio.Lock] = {}

    def _now(self) -> float:
        return time.monotonic()

    def _is_fresh(self, entry: _Entry) -> bool:
        return (self._now() - entry.created_at) < self._ttl

    def _evict_if_full(self) -> None:
        while len(self._entries) >= self._maxsize:
            evicted_key, _ = self._entries.popitem(last=False)
            self._locks.pop(evicted_key, None)
            _perf_log.info(
                "[agent_cache] evict key=%s reason=lru size=%d",
                _short(evicted_key),
                len(self._entries),
            )

    def _touch(self, key: str, entry: _Entry) -> None:
        entry.last_used_at = self._now()
        self._entries.move_to_end(key, last=True)

    async def get_or_build(
        self,
        key: str,
        *,
        builder: Callable[[], Awaitable[Any]],
    ) -> Any:
        """Return the cached value for ``key`` or call ``builder()`` to make it.

        ``builder`` MUST be idempotent — concurrent cold misses on the
        same key collapse to a single ``builder()`` call (the others
        wait on the in-flight lock and observe the populated entry on
        wake).
        """
        # Fast path: hot hit.
        entry = self._entries.get(key)
        if entry is not None and self._is_fresh(entry):
            self._touch(key, entry)
            _perf_log.info(
                "[agent_cache] hit key=%s age=%.1fs size=%d",
                _short(key),
                self._now() - entry.created_at,
                len(self._entries),
            )
            return entry.value

        # Stale entry — drop it; rebuild below.
        if entry is not None and not self._is_fresh(entry):
            _perf_log.info(
                "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
                _short(key),
                self._now() - entry.created_at,
                self._ttl,
            )
            self._entries.pop(key, None)

        # Slow path: serialize concurrent misses for the same key.
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:
            # Double-check after acquiring the lock — another waiter may
            # have populated the entry while we slept.
            entry = self._entries.get(key)
            if entry is not None and self._is_fresh(entry):
                self._touch(key, entry)
                _perf_log.info(
                    "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
                    _short(key),
                    self._now() - entry.created_at,
                    len(self._entries),
                )
                return entry.value

            t0 = time.perf_counter()
            try:
                value = await builder()
            except BaseException:
                # Don't cache failed builds; let the next caller retry.
                _perf_log.warning(
                    "[agent_cache] build_failed key=%s elapsed=%.3fs",
                    _short(key),
                    time.perf_counter() - t0,
                )
                raise
            elapsed = time.perf_counter() - t0

            # Insert + evict.
            self._evict_if_full()
            now = self._now()
            self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
            self._entries.move_to_end(key, last=True)
            _perf_log.info(
                "[agent_cache] miss key=%s build=%.3fs size=%d",
                _short(key),
                elapsed,
                len(self._entries),
            )
            return value

    def invalidate(self, key: str) -> bool:
        """Drop a single entry; return True if anything was removed."""
        removed = self._entries.pop(key, None) is not None
        self._locks.pop(key, None)
        if removed:
            _perf_log.info(
                "[agent_cache] invalidate key=%s size=%d",
                _short(key),
                len(self._entries),
            )
        return removed

    def invalidate_prefix(self, prefix: str) -> int:
        """Drop every entry whose key starts with ``prefix``. Returns count."""
        keys = [k for k in self._entries if k.startswith(prefix)]
        for k in keys:
            self._entries.pop(k, None)
            self._locks.pop(k, None)
        if keys:
            _perf_log.info(
                "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
                _short(prefix),
                len(keys),
                len(self._entries),
            )
        return len(keys)

    def clear(self) -> None:
        n = len(self._entries)
        self._entries.clear()
        self._locks.clear()
        if n:
            _perf_log.info("[agent_cache] clear removed=%d", n)

    def stats(self) -> dict[str, Any]:
        return {
            "size": len(self._entries),
            "maxsize": self._maxsize,
            "ttl_seconds": self._ttl,
        }


def _short(key: str, n: int = 16) -> str:
    """Truncate keys for log lines so they don't blow up log volume."""
    return key if len(key) <= n else f"{key[:n]}..."


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------

_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))

_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)


def get_cache() -> _AgentCache:
    """Return the process-wide compiled-agent cache singleton."""
    return _cache


def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
    """Replace the singleton with a fresh cache. Tests only."""
    global _cache
    _cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
    return _cache


__all__ = [
    "flags_signature",
    "get_cache",
    "reload_for_tests",
    "stable_hash",
    "system_prompt_hash",
    "tools_signature",
]
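Aside, not part of the diff: a short usage sketch for the new `agent_cache` module above. `build_agent_somehow()` is a stand-in for the expensive compile step; the cache API (`get_cache`, `stable_hash`, `get_or_build`) is exactly what the new file defines, including the in-flight coalescing of concurrent cold misses.

```python
# Usage sketch for agent_cache; build_agent_somehow() is a placeholder builder.
import asyncio

from app.agents.new_chat.agent_cache import get_cache, stable_hash


async def build_agent_somehow() -> object:
    # Stand-in for the ~4-5s compile pipeline described in the module docstring.
    await asyncio.sleep(2)
    return object()


async def demo() -> None:
    key = stable_hash("v1", "thread-123", "user-42")  # illustrative key parts
    cache = get_cache()
    # Two concurrent cold misses on the same key coalesce into a single build;
    # the second caller waits on the per-key lock and receives the cached value.
    a, b = await asyncio.gather(
        cache.get_or_build(key, builder=build_agent_somehow),
        cache.get_or_build(key, builder=build_agent_somehow),
    )
    assert a is b
```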
@@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
subclass of the default ``FilesystemMiddleware`` — while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
-summarisation, prompt-caching, etc.).
+summarisation, etc.). Prompt caching is configured at LLM-build time via
+``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
+than as a middleware.
"""

import asyncio

@@ -33,12 +35,18 @@ from langchain.agents.middleware import (
    TodoListMiddleware,
    ToolCallLimitMiddleware,
)
-from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession

+from app.agents.new_chat.agent_cache import (
+    flags_signature,
+    get_cache,
+    stable_hash,
+    system_prompt_hash,
+    tools_signature,
+)
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver

@@ -52,6 +60,7 @@ from app.agents.new_chat.middleware import (
    DedupHITLToolCallsMiddleware,
    DoomLoopMiddleware,
    FileIntentMiddleware,
+    FlattenSystemMessageMiddleware,
    KnowledgeBasePersistenceMiddleware,
    KnowledgePriorityMiddleware,
    KnowledgeTreeMiddleware,

@@ -74,6 +83,7 @@ from app.agents.new_chat.plugin_loader import (
    load_allowed_plugin_names_from_env,
    load_plugin_middlewares,
)
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.subagents import build_specialized_subagents
from app.agents.new_chat.system_prompt import (
    build_configurable_system_prompt,

@@ -94,6 +104,39 @@ from app.utils.perf import get_perf_logger

_perf_log = get_perf_logger()


+def _resolve_prompt_model_name(
+    agent_config: AgentConfig | None,
+    llm: BaseChatModel,
+) -> str | None:
+    """Resolve the model id to feed to provider-variant detection.
+
+    Preference order (matches the established idiom in
+    ``llm_router_service.py`` — see ``params.get("base_model") or
+    params.get("model", "")`` usages there):
+
+    1. ``agent_config.litellm_params["base_model"]`` — required for Azure
+       deployments where ``model_name`` is the deployment slug, not the
+       underlying family. Without this, a deployment named e.g.
+       ``"prod-chat-001"`` would silently miss every provider regex.
+    2. ``agent_config.model_name`` — the user's configured model id.
+    3. ``getattr(llm, "model", None)`` — fallback for direct callers that
+       don't supply an ``AgentConfig`` (currently a defensive path; all
+       production callers pass ``agent_config``).
+
+    Returns ``None`` when nothing is available; ``compose_system_prompt``
+    treats that as the ``"default"`` variant (no provider block emitted).
+    """
+    if agent_config is not None:
+        params = agent_config.litellm_params or {}
+        base_model = params.get("base_model")
+        if isinstance(base_model, str) and base_model.strip():
+            return base_model
+        if agent_config.model_name:
+            return agent_config.model_name
+    return getattr(llm, "model", None)
+
+
# =============================================================================
# Connector Type Mapping
# =============================================================================
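Aside, not part of the diff: a stand-alone illustration of the preference order that `_resolve_prompt_model_name` (added above) implements. `AgentConfig` is faked here with `SimpleNamespace`; the `resolve` helper simply mirrors the same ordering so the asserts can run without the SurfSense codebase.

```python
# Illustration of the base_model -> model_name -> llm.model preference order.
from types import SimpleNamespace


def resolve(cfg, llm) -> str | None:
    # Mirrors the ordering of _resolve_prompt_model_name in the diff above.
    if cfg is not None:
        base = (cfg.litellm_params or {}).get("base_model")
        if isinstance(base, str) and base.strip():
            return base
        if cfg.model_name:
            return cfg.model_name
    return getattr(llm, "model", None)


azure_cfg = SimpleNamespace(model_name="prod-chat-001", litellm_params={"base_model": "azure/gpt-4o"})
plain_cfg = SimpleNamespace(model_name="claude-sonnet-4-5", litellm_params={})

assert resolve(azure_cfg, llm=None) == "azure/gpt-4o"       # base_model wins for Azure deployments
assert resolve(plain_cfg, llm=None) == "claude-sonnet-4-5"  # otherwise the configured model id is used
```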
@@ -279,6 +322,14 @@ async def create_surfsense_deep_agent(
        )
    """
    _t_agent_total = time.perf_counter()
+
+    # Layer thread-aware prompt caching onto the LLM. Idempotent with the
+    # build-time call in ``llm_config.py``; this run merely adds
+    # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
+    # configs now that ``thread_id`` is known. No-op when ``thread_id`` is
+    # None or the provider is non-OpenAI-family.
+    apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
+
    filesystem_selection = filesystem_selection or FilesystemSelection()
    backend_resolver = build_backend_resolver(
        filesystem_selection,

@@ -287,23 +338,39 @@ async def create_surfsense_deep_agent(
        else None,
    )

-    # Discover available connectors and document types for this search space
+    # Discover available connectors and document types for this search space.
+    #
+    # NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
+    # ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
+    # SQLAlchemy explicitly forbids concurrent operations on the same session
+    # ("This session is provisioning a new connection; concurrent operations
+    # are not permitted on the same session"). The Phase 1.4 in-process TTL
+    # cache in ``connector_service`` already collapses the warm path to a
+    # near-zero pair of dict lookups, so sequential awaits cost nothing in
+    # the common case while remaining correct on cold cache misses.
    available_connectors: list[str] | None = None
    available_document_types: list[str] | None = None

    _t0 = time.perf_counter()
    try:
-        connector_types = await connector_service.get_available_connectors(
-            search_space_id
-        )
-        if connector_types:
-            available_connectors = _map_connectors_to_searchable_types(connector_types)
-
-        available_document_types = await connector_service.get_available_document_types(
-            search_space_id
-        )
-    except Exception as e:
+        try:
+            connector_types_result = await connector_service.get_available_connectors(
+                search_space_id
+            )
+            if connector_types_result:
+                available_connectors = _map_connectors_to_searchable_types(
+                    connector_types_result
+                )
+        except Exception as e:
+            logging.warning("Failed to discover available connectors: %s", e)
+
+        try:
+            available_document_types = (
+                await connector_service.get_available_document_types(search_space_id)
+            )
+        except Exception as e:
+            logging.warning("Failed to discover available document types: %s", e)
+    except Exception as e:  # pragma: no cover - defensive outer guard
        logging.warning(f"Failed to discover available connectors/document types: {e}")
    _perf_log.info(
        "[create_agent] Connector/doc-type discovery in %.3fs",
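Aside, not part of the diff: a minimal sketch of the constraint the NOTE above describes. Two queries that share a single SQLAlchemy `AsyncSession` must be awaited one at a time; the table and column names below are illustrative, not the SurfSense schema.

```python
# Sequential awaits on a shared AsyncSession (illustrative table names).
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


async def discover(session: AsyncSession, search_space_id: int) -> tuple[list[str], list[str]]:
    # Correct: one in-flight operation per shared session at a time.
    connectors = [
        r[0]
        for r in await session.execute(
            text("SELECT connector_type FROM connectors WHERE search_space_id = :sid"),
            {"sid": search_space_id},
        )
    ]
    document_types = [
        r[0]
        for r in await session.execute(
            text("SELECT DISTINCT document_type FROM documents WHERE search_space_id = :sid"),
            {"sid": search_space_id},
        )
    ]
    # Wrapping the two awaits in asyncio.gather() would hand the same AsyncSession two
    # concurrent operations and fail with "concurrent operations are not permitted on
    # the same session".
    return connectors, document_types
```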
@@ -398,6 +465,7 @@ async def create_surfsense_deep_agent(
            enabled_tool_names=_enabled_tool_names,
            disabled_tool_names=_user_disabled_tool_names,
            mcp_connector_tools=_mcp_connector_tools,
+            model_name=_resolve_prompt_model_name(agent_config, llm),
        )
    else:
        system_prompt = build_surfsense_system_prompt(

@@ -405,6 +473,7 @@ async def create_surfsense_deep_agent(
            enabled_tool_names=_enabled_tool_names,
            disabled_tool_names=_user_disabled_tool_names,
            mcp_connector_tools=_mcp_connector_tools,
+            model_name=_resolve_prompt_model_name(agent_config, llm),
        )
    _perf_log.info(
        "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0

@@ -424,29 +493,77 @@ async def create_surfsense_deep_agent(
    # entire middleware build + main-graph compile into a single
    # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
    # event loop stays responsive.
+    #
+    # PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
+    # on every per-request value that any middleware in the stack closes
+    # over in ``__init__`` — drop one and you risk leaking state across
+    # threads. Hits collapse this whole block to a microsecond lookup;
+    # misses pay the original CPU cost AND populate the cache.
+    config_id = agent_config.config_id if agent_config is not None else None
+
+    async def _build_agent() -> Any:
+        return await asyncio.to_thread(
+            _build_compiled_agent_blocking,
+            llm=llm,
+            tools=tools,
+            final_system_prompt=final_system_prompt,
+            backend_resolver=backend_resolver,
+            filesystem_mode=filesystem_selection.mode,
+            search_space_id=search_space_id,
+            user_id=user_id,
+            thread_id=thread_id,
+            visibility=visibility,
+            anon_session_id=anon_session_id,
+            available_connectors=available_connectors,
+            available_document_types=available_document_types,
+            # ``mentioned_document_ids`` is consumed by
+            # ``KnowledgePriorityMiddleware`` per turn via
+            # ``runtime.context`` (Phase 1.5). We still pass the
+            # caller-provided list here for the legacy fallback path
+            # (cache disabled / context not propagated) — the middleware
+            # drains its own copy after the first read so a cached graph
+            # never replays stale mentions.
+            mentioned_document_ids=mentioned_document_ids,
+            max_input_tokens=_max_input_tokens,
+            flags=_flags,
+            checkpointer=checkpointer,
+        )
+
    _t0 = time.perf_counter()
-    agent = await asyncio.to_thread(
-        _build_compiled_agent_blocking,
-        llm=llm,
-        tools=tools,
-        final_system_prompt=final_system_prompt,
-        backend_resolver=backend_resolver,
-        filesystem_mode=filesystem_selection.mode,
-        search_space_id=search_space_id,
-        user_id=user_id,
-        thread_id=thread_id,
-        visibility=visibility,
-        anon_session_id=anon_session_id,
-        available_connectors=available_connectors,
-        available_document_types=available_document_types,
-        mentioned_document_ids=mentioned_document_ids,
-        max_input_tokens=_max_input_tokens,
-        flags=_flags,
-        checkpointer=checkpointer,
-    )
+    if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
+        # Cache key components — order matters only for human readability;
+        # the resulting hash is what's stored. Every component must
+        # rotate on a real shape change AND stay stable across identical
+        # invocations.
+        cache_key = stable_hash(
+            "v1",  # schema version of the key — bump if components change
+            config_id,
+            thread_id,
+            user_id,
+            search_space_id,
+            visibility,
+            filesystem_selection.mode,
+            anon_session_id,
+            tools_signature(
+                tools,
+                available_connectors=available_connectors,
+                available_document_types=available_document_types,
+            ),
+            flags_signature(_flags),
+            system_prompt_hash(final_system_prompt),
+            _max_input_tokens,
+            # ``mentioned_document_ids`` deliberately omitted — middleware
+            # reads it from ``runtime.context`` (Phase 1.5).
+        )
+        agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
+    else:
+        agent = await _build_agent()
    _perf_log.info(
-        "[create_agent] Middleware stack + graph compiled in %.3fs",
+        "[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
        time.perf_counter() - _t0,
+        "on"
+        if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
+        else "off",
    )

    _perf_log.info(
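Aside, not part of the diff: the cache key above is designed to rotate on any real change of "agent shape" and stay stable otherwise. A small demonstration using the helpers the new `agent_cache.py` defines; the `ToyFlags` dataclass is a stand-in for `AgentFeatureFlags`.

```python
# Key-rotation behaviour of stable_hash / flags_signature (ToyFlags is illustrative).
from dataclasses import dataclass

from app.agents.new_chat.agent_cache import flags_signature, stable_hash


@dataclass(frozen=True)
class ToyFlags:
    enable_agent_cache: bool = True
    enable_doom_loop: bool = True


key_a = stable_hash("v1", 7, "thread-1", flags_signature(ToyFlags()))
key_b = stable_hash("v1", 7, "thread-1", flags_signature(ToyFlags()))
key_c = stable_hash("v1", 7, "thread-1", flags_signature(ToyFlags(enable_doom_loop=False)))

assert key_a == key_b  # identical inputs -> identical key, so repeat turns hit the cache
assert key_a != key_c  # flipping any flag rotates the key -> cache miss and rebuild
```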
@@ -568,7 +685,6 @@ def _build_compiled_agent_blocking(
        ),
        create_surfsense_compaction_middleware(llm, StateBackend),
        PatchToolCallsMiddleware(),
-        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]

    general_purpose_spec: SubAgent = {  # type: ignore[typeddict-unknown-key]

@@ -998,6 +1114,14 @@ def _build_compiled_agent_blocking(
        noop_mw,
        retry_mw,
        fallback_mw,
+        # Coalesce a multi-text-block system message into one block
+        # immediately before the model call. Sits innermost on the
+        # system-message-mutation chain so it observes every appender
+        # (todo / filesystem / skills / subagents …) and prevents
+        # OpenRouter→Anthropic from redistributing ``cache_control``
+        # across N blocks and tripping Anthropic's 4-breakpoint cap.
+        # See ``middleware/flatten_system.py`` for full rationale.
+        FlattenSystemMessageMiddleware(),
        # Tool-call repair must run after model emits but before
        # permission / dedup / doom-loop interpret the calls.
        repair_mw,
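Aside, not part of the diff: a rough sketch of the coalescing the comment above describes. This is not the actual `FlattenSystemMessageMiddleware` (that lives in `middleware/flatten_system.py`); it only shows the core transform of joining a list of text blocks into a single block so downstream providers see at most one cacheable system block.

```python
# Sketch only: collapse a multi-block system message into one text block.
from langchain_core.messages import SystemMessage


def flatten_system_message(msg: SystemMessage) -> SystemMessage:
    content = msg.content
    if not isinstance(content, list):
        return msg  # already a single string block, nothing to do
    texts = [
        block["text"] if isinstance(block, dict) else str(block)
        for block in content
        if not isinstance(block, dict) or block.get("type") == "text"
    ]
    return SystemMessage(content="\n\n".join(texts))
```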
@ -1010,12 +1134,12 @@ def _build_compiled_agent_blocking(
|
||||||
action_log_mw,
|
action_log_mw,
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
||||||
# Plugin slot — sits just before AnthropicCache so plugin-side
|
# Plugin slot — sits at the tail so plugin-side transforms see the
|
||||||
# transforms see the final tool result and run before any
|
# final tool result. Prompt caching is now applied at LLM build time
|
||||||
# caching heuristics. Multiple plugins in declared order; loader
|
# via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
|
||||||
# filtered by the admin allowlist already.
|
# caching middleware is needed here. Multiple plugins run in declared
|
||||||
|
# order; loader filtered by the admin allowlist already.
|
||||||
*plugin_middlewares,
|
*plugin_middlewares,
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
]
|
||||||
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -1,10 +1,25 @@
 """
 Context schema definitions for SurfSense agents.
 
-This module defines the custom state schema used by the SurfSense deep agent.
+This module defines the per-invocation context object passed to the SurfSense
+deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
+
+The agent's compiled graph is the same across invocations (and cached by
+``agent_cache``), so anything that varies per turn — the user mentions a
+specific document, the front-end issues a unique ``request_id``, etc. —
+MUST live on this context object instead of being captured into a
+middleware ``__init__`` closure. Middlewares read fields back via
+``runtime.context.<field>``; tools read them via ``runtime.context``.
+
+This object is read inside both ``KnowledgePriorityMiddleware`` (for
+``mentioned_document_ids``) and any future middleware that needs
+per-request state without invalidating the compiled-agent cache.
 """
 
-from typing import NotRequired, TypedDict
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TypedDict
 
 
 class FileOperationContractState(TypedDict):
@@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
     turn_id: str
 
 
-class SurfSenseContextSchema(TypedDict):
+@dataclass
+class SurfSenseContextSchema:
     """
-    Custom state schema for the SurfSense deep agent.
+    Per-invocation context for the SurfSense deep agent.
 
-    This extends the default agent state with custom fields.
-    The default state already includes:
-    - messages: Conversation history
-    - todos: Task list from TodoListMiddleware
-    - files: Virtual filesystem from FilesystemMiddleware
+    Defaults are chosen so the dataclass can be safely default-constructed
+    (LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
+    context is supplied — see ``langgraph.runtime.Runtime``). All fields
+    are optional; consumers must None-check before reading.
 
-    We're adding fields needed for knowledge base search:
-    - search_space_id: The user's search space ID
-    - db_session: Database session (injected at runtime)
-    - connector_service: Connector service instance (injected at runtime)
+    Phase 1.5 fields:
+    search_space_id: Search space the request is scoped to.
+    mentioned_document_ids: KB documents the user @-mentioned this turn.
+        Read by ``KnowledgePriorityMiddleware`` to seed its priority
+        list. Stays out of the compiled-agent cache key — that's the
+        whole point of putting it here.
+    file_operation_contract: One-shot file operation contract emitted
+        by ``FileIntentMiddleware`` for the upcoming turn.
+    turn_id / request_id: Correlation IDs surfaced by the streaming
+        task; populated for telemetry.
+
+    Phase 2 will extend with: thread_id, user_id, visibility,
+    filesystem_mode, anon_session_id, available_connectors,
+    available_document_types, created_by_id (everything currently captured
+    by middleware ``__init__`` closures).
     """
 
-    search_space_id: int
-    file_operation_contract: NotRequired[FileOperationContractState]
-    turn_id: NotRequired[str]
-    request_id: NotRequired[str]
-    # These are runtime-injected and won't be serialized
-    # db_session and connector_service are passed when invoking the agent
+    search_space_id: int | None = None
+    mentioned_document_ids: list[int] = field(default_factory=list)
+    file_operation_contract: FileOperationContractState | None = None
+    turn_id: str | None = None
+    request_id: str | None = None
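To make the per-invocation wiring concrete, here is a minimal sketch (not taken from the diff) of how a streaming task could build and pass this context object. The module path, field values, and the surrounding function are illustrative assumptions; only the `context=` keyword and the dataclass fields themselves come from the change above.

# Sketch only; module path and example values are assumed for illustration.
from app.agents.new_chat.context_schema import SurfSenseContextSchema

async def run_turn(agent, user_message: str) -> None:
    ctx = SurfSenseContextSchema(
        search_space_id=7,                  # assumed example id
        mentioned_document_ids=[101, 102],  # docs @-mentioned in this turn
        turn_id="turn-0001",
        request_id="req-abc123",
    )
    # The compiled graph is reused across turns; only ``ctx`` changes.
    async for event in agent.astream_events(
        {"messages": [("user", user_message)]},
        context=ctx,
    ):
        ...  # forward streaming events to the client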
@@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
 
 These flags gate the newer agent middleware (some ported from OpenCode,
 some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
-SurfSense-native). They follow a "default-OFF for risky things,
-default-ON for safe upgrades, master kill-switch for everything new" model.
+SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
+image updates work even when older installs do not have newly introduced
+environment variables. Risky/experimental integrations stay default OFF,
+and the master kill-switch can still disable everything new.
 
 All new middleware checks its flag at agent build time. If the master
 kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
@@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
 Examples
 --------
 
-Local development (recommended for trying everything except doom-loop / selector):
+Defaults:
 
 SURFSENSE_ENABLE_CONTEXT_EDITING=true
 SURFSENSE_ENABLE_COMPACTION_V2=true
 SURFSENSE_ENABLE_RETRY_AFTER=true
+SURFSENSE_ENABLE_MODEL_FALLBACK=false
+SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
+SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
 SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
-SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
-SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
-SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
-SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
+SURFSENSE_ENABLE_PERMISSION=true
+SURFSENSE_ENABLE_DOOM_LOOP=true
+SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
+SURFSENSE_ENABLE_STREAM_PARITY_V2=true
 
 Master kill-switch (overrides everything else):
@@ -60,32 +65,28 @@ class AgentFeatureFlags:
     disable_new_agent_stack: bool = False
 
     # Agent quality — context budget, retry/limits, name-repair, doom-loop
-    enable_context_editing: bool = False
-    enable_compaction_v2: bool = False
-    enable_retry_after: bool = False
+    enable_context_editing: bool = True
+    enable_compaction_v2: bool = True
+    enable_retry_after: bool = True
     enable_model_fallback: bool = False
-    enable_model_call_limit: bool = False
-    enable_tool_call_limit: bool = False
-    enable_tool_call_repair: bool = False
-    enable_doom_loop: bool = (
-        False  # Default OFF until UI handles permission='doom_loop'
-    )
+    enable_model_call_limit: bool = True
+    enable_tool_call_limit: bool = True
+    enable_tool_call_repair: bool = True
+    enable_doom_loop: bool = True
 
     # Safety — permissions, concurrency, tool-set narrowing
-    enable_permission: bool = False  # Default OFF for first deploy
-    enable_busy_mutex: bool = False
+    enable_permission: bool = True
+    enable_busy_mutex: bool = True
     enable_llm_tool_selector: bool = False  # Default OFF — adds per-turn LLM cost
 
     # Skills + subagents
-    enable_skills: bool = False
-    enable_specialized_subagents: bool = False
-    enable_kb_planner_runnable: bool = False
+    enable_skills: bool = True
+    enable_specialized_subagents: bool = True
+    enable_kb_planner_runnable: bool = True
 
     # Snapshot / revert
-    enable_action_log: bool = False
-    enable_revert_route: bool = (
-        False  # Backend ships before UI; route returns 503 until this flips
-    )
+    enable_action_log: bool = True
+    enable_revert_route: bool = True
 
     # Streaming parity v2 — opt in to LangChain's structured
     # ``AIMessageChunk`` content (typed reasoning blocks, tool-input
@@ -94,7 +95,7 @@ class AgentFeatureFlags:
     # text path and the synthetic ``call_<run_id>`` tool-call id (no
     # ``langchainToolCallId`` propagation). Schema migrations 135/136
     # ship unconditionally because they're forward-compatible.
-    enable_stream_parity_v2: bool = False
+    enable_stream_parity_v2: bool = True
 
     # Plugins
     enable_plugin_loader: bool = False
@@ -102,6 +103,41 @@ class AgentFeatureFlags:
     # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
     enable_otel: bool = False
 
+    # Performance — compiled-agent cache (Phase 1 + Phase 2).
+    # When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
+    # graph if the cache key matches (LLM config + thread + tool surface +
+    # flags + system prompt + filesystem mode). Cuts per-turn agent-build
+    # wall clock from ~4-5s to <50µs on cache hits.
+    #
+    # SAFETY (Phase 2 unblocked this default-on):
+    # All connector mutation tools (``tools/notion``, ``tools/gmail``,
+    # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
+    # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
+    # ``tools/teams``, ``tools/luma``, ``connected_accounts``,
+    # ``update_memory``, ``search_surfsense_docs``) now acquire fresh
+    # short-lived ``AsyncSession`` instances per call via
+    # :data:`async_session_maker`. The factory still accepts ``db_session``
+    # for registry compatibility but ``del``'s it immediately — see any
+    # of those files' factory docstrings for the rationale. The ``llm``
+    # closure is per-(provider, model, config_id) which is already in
+    # the cache key, so the LLM is safe to share across cached hits of
+    # the same key. The KB priority middleware reads
+    # ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
+    # not its constructor closure, so the same compiled agent serves
+    # turns with different mention lists correctly.
+    #
+    # Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
+    # environment if a regression surfaces. The path is exercised by
+    # the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
+    enable_agent_cache: bool = True
+    # Phase 1 (deferred — measure first): pre-build & share the
+    # general-purpose subagent ``CompiledSubAgent`` across cold-cache
+    # misses. Only helps when the outer cache MISSES (cache hits already
+    # reuse the entire SubAgentMiddleware-compiled graph). Off by default
+    # until we have data showing cold misses are frequent enough to
+    # justify the extra global state.
+    enable_agent_cache_share_gp_subagent: bool = False
 
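The `from_env` constructor further down reads each of these fields through an `_env_bool` helper whose implementation is not shown in this diff. A plausible minimal version, with the accepted spellings being an assumption, looks like this:

import os

def _env_bool(name: str, default: bool) -> bool:
    """Parse a boolean feature flag from the environment, falling back to a default."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Accepted spellings here are an assumption; the real helper may differ.
    return raw.strip().lower() in {"1", "true", "yes", "on"}

# e.g. SURFSENSE_ENABLE_AGENT_CACHE=false -> False; unset -> True (the new default)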
@classmethod
|
@classmethod
|
||||||
def from_env(cls) -> AgentFeatureFlags:
|
def from_env(cls) -> AgentFeatureFlags:
|
||||||
"""Read flags from environment.
|
"""Read flags from environment.
|
||||||
|
|
@ -115,48 +151,76 @@ class AgentFeatureFlags:
|
||||||
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
||||||
"middleware is forced OFF for this build."
|
"middleware is forced OFF for this build."
|
||||||
)
|
)
|
||||||
return cls(disable_new_agent_stack=True)
|
return cls(
|
||||||
|
disable_new_agent_stack=True,
|
||||||
|
enable_context_editing=False,
|
||||||
|
enable_compaction_v2=False,
|
||||||
|
enable_retry_after=False,
|
||||||
|
enable_model_fallback=False,
|
||||||
|
enable_model_call_limit=False,
|
||||||
|
enable_tool_call_limit=False,
|
||||||
|
enable_tool_call_repair=False,
|
||||||
|
enable_doom_loop=False,
|
||||||
|
enable_permission=False,
|
||||||
|
enable_busy_mutex=False,
|
||||||
|
enable_llm_tool_selector=False,
|
||||||
|
enable_skills=False,
|
||||||
|
enable_specialized_subagents=False,
|
||||||
|
enable_kb_planner_runnable=False,
|
||||||
|
enable_action_log=False,
|
||||||
|
enable_revert_route=False,
|
||||||
|
enable_stream_parity_v2=False,
|
||||||
|
enable_plugin_loader=False,
|
||||||
|
enable_otel=False,
|
||||||
|
enable_agent_cache=False,
|
||||||
|
enable_agent_cache_share_gp_subagent=False,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
disable_new_agent_stack=False,
|
disable_new_agent_stack=False,
|
||||||
# Agent quality
|
# Agent quality
|
||||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
|
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
|
||||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
|
||||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
|
||||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||||
enable_model_call_limit=_env_bool(
|
enable_model_call_limit=_env_bool(
|
||||||
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
|
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
|
||||||
),
|
),
|
||||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
|
||||||
enable_tool_call_repair=_env_bool(
|
enable_tool_call_repair=_env_bool(
|
||||||
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
|
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
|
||||||
),
|
),
|
||||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
|
||||||
# Safety
|
# Safety
|
||||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
|
||||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
|
||||||
enable_llm_tool_selector=_env_bool(
|
enable_llm_tool_selector=_env_bool(
|
||||||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
||||||
),
|
),
|
||||||
# Skills + subagents
|
# Skills + subagents
|
||||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
|
||||||
enable_specialized_subagents=_env_bool(
|
enable_specialized_subagents=_env_bool(
|
||||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
|
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
|
||||||
),
|
),
|
||||||
enable_kb_planner_runnable=_env_bool(
|
enable_kb_planner_runnable=_env_bool(
|
||||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
|
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
|
||||||
),
|
),
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||||
# Streaming parity v2
|
# Streaming parity v2
|
||||||
enable_stream_parity_v2=_env_bool(
|
enable_stream_parity_v2=_env_bool(
|
||||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
|
||||||
),
|
),
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||||
# Observability
|
# Observability
|
||||||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||||
|
# Performance
|
||||||
|
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
|
||||||
|
enable_agent_cache_share_gp_subagent=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def any_new_middleware_enabled(self) -> bool:
|
def any_new_middleware_enabled(self) -> bool:
|
||||||
|
|
|
||||||
|
|
@@ -27,6 +27,7 @@ from litellm import get_model_info
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
 from app.services.llm_router_service import (
     AUTO_MODE_ID,
     ChatLiteLLMRouter,
@@ -89,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
         yield chunk
 
 
-# Provider mapping for LiteLLM model string construction
-PROVIDER_MAP = {
-    "OPENAI": "openai",
-    "ANTHROPIC": "anthropic",
-    "GROQ": "groq",
-    "COHERE": "cohere",
-    "GOOGLE": "gemini",
-    "OLLAMA": "ollama_chat",
-    "MISTRAL": "mistral",
-    "AZURE_OPENAI": "azure",
-    "OPENROUTER": "openrouter",
-    "XAI": "xai",
-    "BEDROCK": "bedrock",
-    "VERTEX_AI": "vertex_ai",
-    "TOGETHER_AI": "together_ai",
-    "FIREWORKS_AI": "fireworks_ai",
-    "DEEPSEEK": "openai",
-    "ALIBABA_QWEN": "openai",
-    "MOONSHOT": "openai",
-    "ZHIPU": "openai",
-    "GITHUB_MODELS": "github",
-    "REPLICATE": "replicate",
-    "PERPLEXITY": "perplexity",
-    "ANYSCALE": "anyscale",
-    "DEEPINFRA": "deepinfra",
-    "CEREBRAS": "cerebras",
-    "SAMBANOVA": "sambanova",
-    "AI21": "ai21",
-    "CLOUDFLARE": "cloudflare",
-    "DATABRICKS": "databricks",
-    "COMETAPI": "cometapi",
-    "HUGGINGFACE": "huggingface",
-    "MINIMAX": "openai",
-    "CUSTOM": "custom",
-}
+# Provider mapping for LiteLLM model string construction.
+#
+# Single source of truth lives in
+# :mod:`app.services.provider_capabilities` so the YAML loader (which
+# runs during ``app.config`` class-body init) can resolve provider
+# prefixes without dragging the agent / tools tree into module load
+# order. Re-exported here under the historical ``PROVIDER_MAP`` name
+# so existing callers (``llm_router_service``, ``image_gen_router_service``,
+# tests) keep working unchanged.
+from app.services.provider_capabilities import (  # noqa: E402
+    _PROVIDER_PREFIX_MAP as PROVIDER_MAP,
+)
 
 
 def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
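For orientation, this prefix map feeds LiteLLM model strings of the form `<prefix>/<model_name>`. The helper below is a simplified illustration of that lookup, not the project's actual code, and the example model name is also just an assumption.

def build_litellm_model_string(provider: str, model_name: str) -> str:
    """Illustrative only: map an uppercase provider key to its LiteLLM prefix."""
    prefix = PROVIDER_MAP.get(provider.upper())
    if prefix is None:
        raise ValueError(f"Unknown provider: {provider}")
    return f"{prefix}/{model_name}"

# build_litellm_model_string("ANTHROPIC", "claude-3-5-sonnet-20241022")
# -> "anthropic/claude-3-5-sonnet-20241022"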
@@ -177,6 +155,17 @@ class AgentConfig:
     anonymous_enabled: bool = False
     quota_reserve_tokens: int | None = None
 
+    # Capability flag: best-effort True for the chat selector / catalog.
+    # Resolved via :func:`provider_capabilities.derive_supports_image_input`
+    # which prefers OpenRouter's ``architecture.input_modalities`` and
+    # otherwise consults LiteLLM's authoritative model map. Default True
+    # is the conservative-allow stance — the streaming-task safety net
+    # (``is_known_text_only_chat_model``) is the *only* place a False
+    # actually blocks a request. Setting this to False here without an
+    # authoritative source would silently hide vision-capable models
+    # (the regression we're fixing).
+    supports_image_input: bool = True
 
     @classmethod
     def from_auto_mode(cls) -> "AgentConfig":
         """
@@ -202,6 +191,12 @@ class AgentConfig:
             is_premium=False,
             anonymous_enabled=False,
             quota_reserve_tokens=None,
+            # Auto routes across the configured pool, which usually
+            # contains at least one vision-capable deployment; the router
+            # will surface a 404 from a non-vision deployment as a normal
+            # ``allowed_fails`` event and fail over rather than blocking
+            # the request outright.
+            supports_image_input=True,
         )
 
     @classmethod
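`derive_supports_image_input` lives in `app.services.provider_capabilities` and is not shown in this diff. The sketch below only illustrates the default-allow stance the comment describes, using LiteLLM's model map; the function name and the exact fallback behaviour are assumptions.

from litellm import get_model_info

def guess_supports_image_input(model_string: str) -> bool:
    """Default-allow capability probe: only an explicit 'no' blocks image input."""
    try:
        info = get_model_info(model=model_string)
    except Exception:
        # Model not in LiteLLM's map: stay permissive, matching the
        # conservative-allow stance documented above.
        return True
    return bool(info.get("supports_vision", True))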
|
|
@ -215,10 +210,24 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
return cls(
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
provider=config.provider.value
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
|
provider_value = (
|
||||||
|
config.provider.value
|
||||||
if hasattr(config.provider, "value")
|
if hasattr(config.provider, "value")
|
||||||
else str(config.provider),
|
else str(config.provider)
|
||||||
|
)
|
||||||
|
litellm_params = config.litellm_params or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
provider=provider_value,
|
||||||
model_name=config.model_name,
|
model_name=config.model_name,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
api_base=config.api_base,
|
api_base=config.api_base,
|
||||||
|
|
@ -234,6 +243,16 @@ class AgentConfig:
|
||||||
is_premium=False,
|
is_premium=False,
|
||||||
anonymous_enabled=False,
|
anonymous_enabled=False,
|
||||||
quota_reserve_tokens=None,
|
quota_reserve_tokens=None,
|
||||||
|
# BYOK rows have no operator-curated capability flag, so we
|
||||||
|
# ask LiteLLM (default-allow on unknown). The streaming
|
||||||
|
# safety net still blocks if the model is *explicitly*
|
||||||
|
# marked text-only.
|
||||||
|
supports_image_input=derive_supports_image_input(
|
||||||
|
provider=provider_value,
|
||||||
|
model_name=config.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=config.custom_provider,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -252,15 +271,46 @@ class AgentConfig:
|
||||||
Returns:
|
Returns:
|
||||||
AgentConfig instance
|
AgentConfig instance
|
||||||
"""
|
"""
|
||||||
|
# Lazy import to avoid pulling provider_capabilities (and its
|
||||||
|
# transitive litellm import) into module-init order.
|
||||||
|
from app.services.provider_capabilities import derive_supports_image_input
|
||||||
|
|
||||||
# Get system instructions from YAML, default to empty string
|
# Get system instructions from YAML, default to empty string
|
||||||
system_instructions = yaml_config.get("system_instructions", "")
|
system_instructions = yaml_config.get("system_instructions", "")
|
||||||
|
|
||||||
|
provider = yaml_config.get("provider", "").upper()
|
||||||
|
model_name = yaml_config.get("model_name", "")
|
||||||
|
custom_provider = yaml_config.get("custom_provider")
|
||||||
|
litellm_params = yaml_config.get("litellm_params") or {}
|
||||||
|
base_model = (
|
||||||
|
litellm_params.get("base_model")
|
||||||
|
if isinstance(litellm_params, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||||
|
# OpenRouter modalities. The YAML loader already populates this
|
||||||
|
# field, but this method is also called from
|
||||||
|
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||||
|
# so we re-derive here for safety. The bool() coercion preserves
|
||||||
|
# the loader's behaviour for explicit ``true`` / ``false``
|
||||||
|
# strings that PyYAML may surface.
|
||||||
|
if "supports_image_input" in yaml_config:
|
||||||
|
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||||
|
else:
|
||||||
|
supports_image_input = derive_supports_image_input(
|
||||||
|
provider=provider,
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_provider=custom_provider,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
provider=yaml_config.get("provider", "").upper(),
|
provider=provider,
|
||||||
model_name=yaml_config.get("model_name", ""),
|
model_name=model_name,
|
||||||
api_key=yaml_config.get("api_key", ""),
|
api_key=yaml_config.get("api_key", ""),
|
||||||
api_base=yaml_config.get("api_base"),
|
api_base=yaml_config.get("api_base"),
|
||||||
custom_provider=yaml_config.get("custom_provider"),
|
custom_provider=custom_provider,
|
||||||
litellm_params=yaml_config.get("litellm_params"),
|
litellm_params=yaml_config.get("litellm_params"),
|
||||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||||
system_instructions=system_instructions if system_instructions else None,
|
system_instructions=system_instructions if system_instructions else None,
|
||||||
|
|
@ -275,6 +325,7 @@ class AgentConfig:
|
||||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||||
|
supports_image_input=supports_image_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -494,6 +545,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
||||||
|
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
||||||
|
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
||||||
|
# in a structured form, so we set only the universal injection points.
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -518,7 +574,16 @@ def create_chat_litellm_from_agent_config(
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
return get_auto_mode_llm()
|
router_llm = get_auto_mode_llm()
|
||||||
|
if router_llm is not None:
|
||||||
|
# Universal cache_control_injection_points only — auto-mode
|
||||||
|
# fans out across providers, so OpenAI-only kwargs (e.g.
|
||||||
|
# ``prompt_cache_key``) are left off here. ``drop_params``
|
||||||
|
# would strip them at the provider boundary anyway, but
|
||||||
|
# there's no point setting them when we don't know the
|
||||||
|
# destination.
|
||||||
|
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||||
|
return router_llm
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -549,4 +614,9 @@ def create_chat_litellm_from_agent_config(
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
||||||
|
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
||||||
|
# Per-thread ``prompt_cache_key`` is layered on later in
|
||||||
|
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||||
return llm
|
return llm
|
||||||
|
|
|
||||||
|
|
@@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
 from app.agents.new_chat.middleware.filesystem import (
     SurfSenseFilesystemMiddleware,
 )
+from app.agents.new_chat.middleware.flatten_system import (
+    FlattenSystemMessageMiddleware,
+)
 from app.agents.new_chat.middleware.kb_persistence import (
     KnowledgeBasePersistenceMiddleware,
     commit_staged_filesystem_state,
@@ -61,6 +64,7 @@ __all__ = [
     "DedupHITLToolCallsMiddleware",
     "DoomLoopMiddleware",
     "FileIntentMiddleware",
+    "FlattenSystemMessageMiddleware",
     "KnowledgeBasePersistenceMiddleware",
     "KnowledgeBaseSearchMiddleware",
     "KnowledgePriorityMiddleware",
|
|
@@ -33,6 +33,7 @@ from __future__ import annotations
 
 import asyncio
 import logging
+import time
 import weakref
 from typing import Any
 
@@ -58,6 +59,11 @@ class _ThreadLockManager:
             weakref.WeakValueDictionary()
         )
         self._cancel_events: dict[str, asyncio.Event] = {}
+        self._cancel_requested_at_ms: dict[str, int] = {}
+        self._cancel_attempt_count: dict[str, int] = {}
+        # Monotonic per-thread epoch used to prevent stale middleware
+        # teardown from releasing a newer turn's lock.
+        self._turn_epoch: dict[str, int] = {}
 
     def lock_for(self, thread_id: str) -> asyncio.Lock:
         lock = self._locks.get(thread_id)
@@ -76,14 +82,57 @@ class _ThreadLockManager:
     def request_cancel(self, thread_id: str) -> bool:
         event = self._cancel_events.get(thread_id)
         if event is None:
-            return False
+            event = asyncio.Event()
+            self._cancel_events[thread_id] = event
         event.set()
+        now_ms = int(time.time() * 1000)
+        self._cancel_requested_at_ms[thread_id] = now_ms
+        self._cancel_attempt_count[thread_id] = (
+            self._cancel_attempt_count.get(thread_id, 0) + 1
+        )
         return True
 
+    def is_cancel_requested(self, thread_id: str) -> bool:
+        event = self._cancel_events.get(thread_id)
+        return bool(event and event.is_set())
+
+    def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
+        if not self.is_cancel_requested(thread_id):
+            return None
+        attempts = self._cancel_attempt_count.get(thread_id, 1)
+        requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
+        return attempts, requested_at_ms
+
     def reset(self, thread_id: str) -> None:
         event = self._cancel_events.get(thread_id)
         if event is not None:
             event.clear()
+        self._cancel_requested_at_ms.pop(thread_id, None)
+        self._cancel_attempt_count.pop(thread_id, None)
+
+    def bump_turn_epoch(self, thread_id: str) -> int:
+        epoch = self._turn_epoch.get(thread_id, 0) + 1
+        self._turn_epoch[thread_id] = epoch
+        return epoch
+
+    def current_turn_epoch(self, thread_id: str) -> int:
+        return self._turn_epoch.get(thread_id, 0)
+
+    def end_turn(self, thread_id: str) -> None:
+        """Best-effort terminal cleanup for a thread turn.
+
+        This is intentionally idempotent and safe to call from outer stream
+        finally-blocks where middleware teardown might be skipped due to abort
+        or disconnect edge-cases.
+        """
+        # Invalidate any in-flight middleware holder first. This guarantees a
+        # stale ``aafter_agent`` from an older attempt cannot unlock a newer
+        # retry that already acquired the lock for the same thread.
+        self.bump_turn_epoch(thread_id)
+        lock = self._locks.get(thread_id)
+        if lock is not None and lock.locked():
+            lock.release()
+        self.reset(thread_id)
+
     def release(self, thread_id: str) -> bool:
         """Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
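A toy driver (not production wiring) showing how the new epoch guard keeps a stale teardown from releasing a lock that a newer turn now holds; the thread id is an arbitrary example and `manager` stands for this module's singleton.

import asyncio

async def demo(manager, thread_id: str = "thread-42") -> None:
    lock = manager.lock_for(thread_id)

    # Turn 1 acquires the lock and records the epoch it saw.
    await lock.acquire()
    old_epoch = manager.bump_turn_epoch(thread_id)

    # Something goes wrong mid-turn; the outer stream forces cleanup.
    manager.end_turn(thread_id)            # bumps the epoch and releases the lock

    # Turn 2 starts a fresh attempt on the same thread.
    await lock.acquire()
    new_epoch = manager.bump_turn_epoch(thread_id)

    # Turn 1's late teardown fires now. The epoch check makes it a no-op:
    if old_epoch == manager.current_turn_epoch(thread_id):
        lock.release()                     # never reached; old_epoch is stale
    assert lock.locked() and new_epoch != old_epoch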
@ -115,18 +164,28 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||||
|
|
||||||
|
|
||||||
def request_cancel(thread_id: str) -> bool:
|
def request_cancel(thread_id: str) -> bool:
|
||||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||||
return manager.request_cancel(thread_id)
|
return manager.request_cancel(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cancel_requested(thread_id: str) -> bool:
|
||||||
|
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||||
|
return manager.is_cancel_requested(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||||
|
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||||
|
return manager.cancel_state(thread_id)
|
||||||
|
|
||||||
|
|
||||||
def reset_cancel(thread_id: str) -> None:
|
def reset_cancel(thread_id: str) -> None:
|
||||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||||
manager.reset(thread_id)
|
manager.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
def release_lock(thread_id: str) -> bool:
|
def end_turn(thread_id: str) -> None:
|
||||||
"""Force-release the per-thread busy lock; safe to call when not held."""
|
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||||
return manager.release(thread_id)
|
manager.end_turn(thread_id)
|
||||||
|
|
||||||
|
|
||||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||||
|
|
@ -151,10 +210,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._require_thread_id = require_thread_id
|
self._require_thread_id = require_thread_id
|
||||||
self.tools = []
|
self.tools = []
|
||||||
# Per-call locks owned by this middleware. We track them as
|
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||||
# an instance attribute so ``aafter_agent`` knows which lock
|
# only releases when its epoch still matches the manager's current
|
||||||
# to release.
|
# epoch for the thread, preventing stale unlock races.
|
||||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||||
|
|
@ -205,7 +264,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
if lock.locked():
|
if lock.locked():
|
||||||
raise BusyError(request_id=thread_id)
|
raise BusyError(request_id=thread_id)
|
||||||
await lock.acquire()
|
await lock.acquire()
|
||||||
self._held_locks[thread_id] = lock
|
epoch = manager.bump_turn_epoch(thread_id)
|
||||||
|
self._held_locks[thread_id] = (lock, epoch)
|
||||||
# Reset the cancel event so this turn starts fresh
|
# Reset the cancel event so this turn starts fresh
|
||||||
reset_cancel(thread_id)
|
reset_cancel(thread_id)
|
||||||
return None
|
return None
|
||||||
|
|
@ -219,8 +279,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
thread_id = self._thread_id(runtime)
|
thread_id = self._thread_id(runtime)
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
return None
|
return None
|
||||||
lock = self._held_locks.pop(thread_id, None)
|
held = self._held_locks.pop(thread_id, None)
|
||||||
if lock is not None and lock.locked():
|
if held is None:
|
||||||
|
return None
|
||||||
|
lock, held_epoch = held
|
||||||
|
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||||
|
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||||
|
# already advanced epoch). Do not touch current lock/cancel state.
|
||||||
|
return None
|
||||||
|
if lock.locked():
|
||||||
lock.release()
|
lock.release()
|
||||||
# Always clear cancel event between turns so a stale signal
|
# Always clear cancel event between turns so a stale signal
|
||||||
# doesn't leak into the next request.
|
# doesn't leak into the next request.
|
||||||
|
|
@ -251,9 +318,11 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BusyMutexMiddleware",
|
"BusyMutexMiddleware",
|
||||||
|
"end_turn",
|
||||||
"get_cancel_event",
|
"get_cancel_event",
|
||||||
|
"get_cancel_state",
|
||||||
|
"is_cancel_requested",
|
||||||
"manager",
|
"manager",
|
||||||
"release_lock",
|
|
||||||
"request_cancel",
|
"request_cancel",
|
||||||
"reset_cancel",
|
"reset_cancel",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,233 @@
|
||||||
|
r"""Coalesce multi-block system messages into a single text block.
|
||||||
|
|
||||||
|
Several middlewares in our deepagent stack each call
|
||||||
|
``append_to_system_message`` on the way down to the model
|
||||||
|
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||||
|
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||||
|
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||||
|
|
||||||
|
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||||
|
request**, and we configure 2 injection points
|
||||||
|
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||||
|
the prepended ``request.system_message``, this middleware is the
|
||||||
|
defensive partner: it guarantees that "the system block" is *one*
|
||||||
|
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||||
|
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||||
|
several breakpoints by spreading ``cache_control`` across multiple
|
||||||
|
text blocks of a multi-block system content.
|
||||||
|
|
||||||
|
Without flattening we used to see::
|
||||||
|
|
||||||
|
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||||
|
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||||
|
cache_control may be provided. Found 5."}}}
|
||||||
|
|
||||||
|
(Same error class documented in
|
||||||
|
https://github.com/BerriAI/litellm/issues/15696 and
|
||||||
|
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||||
|
in PR #15395 covers the litellm transformer but does not protect us
|
||||||
|
when the OpenRouter SaaS itself does the redistribution.)
|
||||||
|
|
||||||
|
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
|
||||||
|
the first injection point from ``role: system`` to ``index: 0``)
|
||||||
|
neutralises the *primary* cause of the same 400 — multiple
|
||||||
|
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||||
|
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||||
|
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||||
|
matcher. This middleware remains useful as defence-in-depth against
|
||||||
|
the multi-block redistribution path.
|
||||||
|
|
||||||
|
Placement: innermost on the system-message-mutation chain, after every
|
||||||
|
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||||
|
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||||
|
attempt sees a flattened payload. See ``chat_deepagent.py``.
|
||||||
|
|
||||||
|
Idempotent: a string-content system message is left untouched. A list
|
||||||
|
that contains anything other than plain text blocks (e.g. an image) is
|
||||||
|
also left untouched — those are rare on system messages and we'd lose
|
||||||
|
the non-text payload by joining.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ResponseT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_text_blocks(content: list[Any]) -> str | None:
|
||||||
|
"""Return joined text if every block is a plain ``{"type": "text"}``.
|
||||||
|
|
||||||
|
Returns ``None`` when the list contains anything that isn't a text
|
||||||
|
block we can safely concatenate (image, audio, file, non-standard
|
||||||
|
blocks, dicts with extra non-cache_control fields). The caller
|
||||||
|
leaves the original content untouched in that case rather than
|
||||||
|
silently dropping payload.
|
||||||
|
|
||||||
|
``cache_control`` on individual blocks is intentionally discarded —
|
||||||
|
the whole point of flattening is to let LiteLLM's
|
||||||
|
``cache_control_injection_points`` re-place a single breakpoint on
|
||||||
|
the resulting one-block system content.
|
||||||
|
"""
|
||||||
|
chunks: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
chunks.append(block)
|
||||||
|
continue
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
return None
|
||||||
|
if block.get("type") != "text":
|
||||||
|
return None
|
||||||
|
text = block.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return None
|
||||||
|
chunks.append(text)
|
||||||
|
return "\n\n".join(chunks)
|
||||||
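A small, invented input/output pair showing what `_flatten_text_blocks` does on concrete values (the block contents are made up for illustration):

blocks = [
    {"type": "text", "text": "You are SurfSense.", "cache_control": {"type": "ephemeral"}},
    {"type": "text", "text": "## Todo middleware notes"},
    {"type": "text", "text": "## Filesystem middleware notes"},
]

# All blocks are plain text, so they are joined with a blank line between them;
# per-block cache_control markers are intentionally dropped.
assert _flatten_text_blocks(blocks) == (
    "You are SurfSense.\n\n## Todo middleware notes\n\n## Filesystem middleware notes"
)

# Anything that is not a plain text block makes the helper bail out with None.
assert _flatten_text_blocks([{"type": "image_url", "image_url": {"url": "..."}}]) is None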
|
|
||||||
|
|
||||||
|
def _flattened_request(
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
) -> ModelRequest[ContextT] | None:
|
||||||
|
"""Return a request with system_message flattened, or ``None`` for no-op."""
|
||||||
|
sys_msg = request.system_message
|
||||||
|
if sys_msg is None:
|
||||||
|
return None
|
||||||
|
content = sys_msg.content
|
||||||
|
if not isinstance(content, list) or len(content) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
flattened = _flatten_text_blocks(content)
|
||||||
|
if flattened is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_sys = SystemMessage(
|
||||||
|
content=flattened,
|
||||||
|
additional_kwargs=dict(sys_msg.additional_kwargs),
|
||||||
|
response_metadata=dict(sys_msg.response_metadata),
|
||||||
|
)
|
||||||
|
if sys_msg.id is not None:
|
||||||
|
new_sys.id = sys_msg.id
|
||||||
|
return request.override(system_message=new_sys)
|
||||||
|
|
||||||
|
|
||||||
|
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
|
||||||
|
"""One-line dump of cache_control-relevant request shape.
|
||||||
|
|
||||||
|
Temporary diagnostic to prove where the ``Found N`` cache_control
|
||||||
|
breakpoints are coming from when Anthropic 400s. Removed once the
|
||||||
|
root cause is confirmed and a fix is in place.
|
||||||
|
"""
|
||||||
|
sys_msg = request.system_message
|
||||||
|
if sys_msg is None:
|
||||||
|
sys_shape = "none"
|
||||||
|
elif isinstance(sys_msg.content, str):
|
||||||
|
sys_shape = f"str(len={len(sys_msg.content)})"
|
||||||
|
elif isinstance(sys_msg.content, list):
|
||||||
|
sys_shape = f"list(blocks={len(sys_msg.content)})"
|
||||||
|
else:
|
||||||
|
sys_shape = f"other({type(sys_msg.content).__name__})"
|
||||||
|
|
||||||
|
role_hist: list[str] = []
|
||||||
|
multi_block_msgs = 0
|
||||||
|
msgs_with_cc = 0
|
||||||
|
sys_msgs_in_history = 0
|
||||||
|
for m in request.messages:
|
||||||
|
mtype = getattr(m, "type", type(m).__name__)
|
||||||
|
role_hist.append(mtype)
|
||||||
|
if isinstance(m, SystemMessage):
|
||||||
|
sys_msgs_in_history += 1
|
||||||
|
c = getattr(m, "content", None)
|
||||||
|
if isinstance(c, list):
|
||||||
|
multi_block_msgs += 1
|
||||||
|
for blk in c:
|
||||||
|
if isinstance(blk, dict) and "cache_control" in blk:
|
||||||
|
msgs_with_cc += 1
|
||||||
|
break
|
||||||
|
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
|
||||||
|
msgs_with_cc += 1
|
||||||
|
|
||||||
|
tools = request.tools or []
|
||||||
|
tools_with_cc = 0
|
||||||
|
for t in tools:
|
||||||
|
if isinstance(t, dict) and (
|
||||||
|
"cache_control" in t or "cache_control" in t.get("function", {})
|
||||||
|
):
|
||||||
|
tools_with_cc += 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"sys={sys_shape} msgs={len(request.messages)} "
|
||||||
|
f"sys_msgs_in_history={sys_msgs_in_history} "
|
||||||
|
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
|
||||||
|
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
|
||||||
|
f"roles={role_hist[-8:]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenSystemMessageMiddleware(
|
||||||
|
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||||
|
):
|
||||||
|
"""Collapse a multi-text-block system message to a single string.
|
||||||
|
|
||||||
|
Sits innermost on the system-message-mutation chain so it observes
|
||||||
|
every middleware's contribution. Has no other side effect — the
|
||||||
|
body of every block is preserved, just joined with ``"\\n\\n"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.tools = []
|
||||||
|
|
||||||
|
def wrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||||
|
) -> Any:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
if flattened is not None:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(
|
||||||
|
"[flatten_system] collapsed %d system blocks to one",
|
||||||
|
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||||
|
)
|
||||||
|
return handler(flattened)
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
|
async def awrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[ContextT],
|
||||||
|
handler: Callable[
|
||||||
|
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||||
|
],
|
||||||
|
) -> Any:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||||
|
flattened = _flattened_request(request)
|
||||||
|
if flattened is not None:
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(
|
||||||
|
"[flatten_system] collapsed %d system blocks to one",
|
||||||
|
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||||
|
)
|
||||||
|
return await handler(flattened)
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FlattenSystemMessageMiddleware",
|
||||||
|
"_flatten_text_blocks",
|
||||||
|
"_flattened_request",
|
||||||
|
]
|
||||||
|
|
@@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         state: AgentState,
         runtime: Runtime[Any],
     ) -> dict[str, Any] | None:
-        del runtime
         if self.filesystem_mode != FilesystemMode.CLOUD:
             return None
 
@@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         if anon_doc:
             return self._anon_priority(state, anon_doc)
 
-        return await self._authenticated_priority(state, messages, user_text)
+        return await self._authenticated_priority(state, messages, user_text, runtime)
 
     def _anon_priority(
         self,
@@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware):  # type: ignore[type-arg]
         state: AgentState,
         messages: Sequence[BaseMessage],
         user_text: str,
+        runtime: Runtime[Any] | None = None,
     ) -> dict[str, Any]:
         t0 = asyncio.get_event_loop().time()
         (
@@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware):  # type: ignore[type-arg]
             user_text=user_text,
         )
 
+        # Per-turn ``mentioned_document_ids`` flow:
+        # 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
+        #    streaming task supplies a fresh :class:`SurfSenseContextSchema`
+        #    on every ``astream_events`` call, so this list is naturally
+        #    scoped to the current turn. Allows cross-turn graph reuse via
+        #    ``agent_cache``.
+        # 2. Legacy fallback (cache disabled / context not propagated): the
+        #    constructor-injected ``self.mentioned_document_ids`` list. We
+        #    drain it after the first read so a cached graph (no Phase 1.5
+        #    wiring) doesn't keep replaying the same mentions on every
+        #    turn.
+        #
+        # CRITICAL: distinguish "context absent" (legacy caller, no field at
+        # all) from "context provided but empty" (turn with no mentions).
+        # ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
+        # Python, so a naive ``if ctx_mentions:`` would fall through to the
+        # legacy closure on every no-mention follow-up turn — replaying the
+        # mentions baked in by turn 1's cache-miss build. Always drain the
+        # closure once the runtime path has fired so a cached middleware
+        # instance can never resurrect stale state.
+        mention_ids: list[int] = []
+        ctx = getattr(runtime, "context", None) if runtime is not None else None
+        ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
+        if ctx_mentions is not None:
+            # Runtime path is authoritative — even an empty list means
+            # "this turn has no mentions", NOT "look at the closure".
+            mention_ids = list(ctx_mentions)
+            if self.mentioned_document_ids:
+                self.mentioned_document_ids = []
+        elif self.mentioned_document_ids:
+            mention_ids = list(self.mentioned_document_ids)
+            self.mentioned_document_ids = []
+
         mentioned_results: list[dict[str, Any]] = []
-        if self.mentioned_document_ids:
+        if mention_ids:
             mentioned_results = await fetch_mentioned_documents(
-                document_ids=self.mentioned_document_ids,
+                document_ids=mention_ids,
                 search_space_id=self.search_space_id,
             )
-            self.mentioned_document_ids = []
 
         if is_recency:
             doc_types = _resolve_search_types(
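Condensed, the precedence rule the new comment block describes is that a present runtime context always wins, even when it is empty, and the constructor closure is drained either way. A hedged distillation of that rule, separate from the real middleware code:

def resolve_mention_ids(ctx_mentions, closure_mentions):
    """Illustrative reduction of the precedence logic above.

    ctx_mentions: list[int] | None  - from runtime.context (None means no context wired)
    closure_mentions: list[int]     - legacy constructor-injected list (mutated in place)
    """
    if ctx_mentions is not None:
        # Runtime path is authoritative; an empty list means "no mentions this turn".
        closure_mentions.clear()
        return list(ctx_mentions)
    # Legacy fallback: consume the closure exactly once.
    drained = list(closure_mentions)
    closure_mentions.clear()
    return drained

assert resolve_mention_ids([], [5, 6]) == []        # context present but empty
assert resolve_mention_ids(None, [5, 6]) == [5, 6]  # legacy caller, drained once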
188
surfsense_backend/app/agents/new_chat/prompt_caching.py
Normal file
@ -0,0 +1,188 @@
|
||||||
|
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||||
|
|
||||||
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||||
|
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||||
|
gate always failed) with LiteLLM's universal caching mechanism.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||||
|
performs automatically when ``cache_control_injection_points`` is set):
|
||||||
|
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||||
|
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||||
|
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||||
|
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||||
|
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
|
||||||
|
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
We inject **two** breakpoints per request:
|
||||||
|
|
||||||
|
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||||
|
request (provider variant, citation rules, tool catalog, KB tree,
|
||||||
|
skills metadata). The langchain agent factory always prepends
|
||||||
|
``request.system_message`` at index 0 (see ``factory.py``
|
||||||
|
``_execute_model_async``), so this targets exactly the main system
|
||||||
|
prompt regardless of how many other ``SystemMessage``\ s the
|
||||||
|
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||||
|
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||||
|
``role: system`` here would apply ``cache_control`` to **every**
|
||||||
|
system-role message and trip Anthropic's hard cap of 4 cache
|
||||||
|
breakpoints per request once the conversation accumulates enough
|
||||||
|
injected system messages — which surfaces as the upstream 400
|
||||||
|
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||||
|
via OpenRouter→Anthropic.
|
||||||
|
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||||
|
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||||
|
N+1 still reads turn N's cache up to the shared prefix.
|
||||||
|
|
||||||
|
For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
|
backend.
|
||||||
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
|
5-10 min in-memory cache.
|
||||||
|
|
||||||
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
provider doesn't recognise is auto-stripped at the provider transformer
|
||||||
|
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||||
|
``prompt_cache_key`` etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Two-breakpoint policy: head-of-request + latest message. See module
|
||||||
|
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
||||||
|
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||||
|
#
|
||||||
|
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||||
|
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
|
||||||
|
# anonymous-doc) insert ``SystemMessage`` instances into
|
||||||
|
# ``state["messages"]`` that accumulate across turns. With
|
||||||
|
# ``role: system`` the LiteLLM hook would tag *every* one of them with
|
||||||
|
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
|
||||||
|
# always targets the langchain-prepended ``request.system_message``
|
||||||
|
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
|
||||||
|
# block), giving us exactly one stable cache breakpoint.
|
||||||
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
|
{"location": "message", "index": 0},
|
||||||
|
{"location": "message", "index": -1},
|
||||||
|
)
|
||||||
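In effect, `apply_litellm_prompt_caching` places these two injection points into the model's `model_kwargs`, from where LiteLLM injects the `cache_control` markers at request time. The snippet below is a rough, simplified picture of the resulting kwargs; the example model and the direct assignment are assumptions, since the real helper is idempotent and respects pre-existing overrides.

from langchain_community.chat_models import ChatLiteLLM

# Assumed example instance; the repo's SanitizedChatLiteLLM subclasses ChatLiteLLM.
llm = ChatLiteLLM(model="anthropic/claude-3-5-sonnet-20241022")

llm.model_kwargs = {
    **(llm.model_kwargs or {}),
    "cache_control_injection_points": [
        {"location": "message", "index": 0},   # the flattened system prompt
        {"location": "message", "index": -1},  # the latest message this turn
    ],
}
# OpenAI-family configs (OPENAI / DEEPSEEK / XAI) additionally get, per the
# module docstring above:
#   "prompt_cache_key": f"surfsense-thread-{thread_id}"
#   "prompt_cache_retention": "24h"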
|
|
||||||
|
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
|
||||||
|
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
|
||||||
|
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
|
||||||
|
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
|
||||||
|
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
|
||||||
|
# MINIMAX), so we can't infer family from the litellm prefix alone.
|
||||||
|
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
|
||||||
|
|
||||||
|
|
||||||
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
|
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
||||||
|
|
||||||
|
Importing ``app.services.llm_router_service`` at module-load time would
|
||||||
|
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
||||||
|
Class-name comparison is sufficient since the class is defined in a
|
||||||
|
single place.
|
||||||
|
"""
|
||||||
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
||||||
|
"""Whether the config targets an OpenAI-style prompt-cache surface.
|
||||||
|
|
||||||
|
Strict — only returns True when the user explicitly chose OPENAI,
|
||||||
|
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
|
||||||
|
``YAMLConfig``. Auto-mode and custom providers return False because
|
||||||
|
we can't statically know the destination.
|
||||||
|
"""
|
||||||
|
if agent_config is None or not agent_config.provider:
|
||||||
|
return False
|
||||||
|
if agent_config.is_auto_mode:
|
||||||
|
return False
|
||||||
|
if agent_config.custom_provider:
|
||||||
|
return False
|
||||||
|
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||||
|
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
|
||||||
|
|
||||||
|
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
|
||||||
|
model. Returns ``None`` if the LLM type doesn't expose a writable
|
||||||
|
``model_kwargs`` attribute (caller should treat as no-op).
|
||||||
|
"""
|
||||||
|
model_kwargs = getattr(llm, "model_kwargs", None)
|
||||||
|
if isinstance(model_kwargs, dict):
|
||||||
|
return model_kwargs
|
||||||
|
try:
|
||||||
|
llm.model_kwargs = {} # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
refreshed = getattr(llm, "model_kwargs", None)
|
||||||
|
return refreshed if isinstance(refreshed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_litellm_prompt_caching(
|
||||||
|
llm: BaseChatModel,
|
||||||
|
*,
|
||||||
|
agent_config: AgentConfig | None = None,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
||||||
|
|
||||||
|
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
|
||||||
|
``agent_config.litellm_params`` overrides) are preserved. Mutates
|
||||||
|
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
|
||||||
|
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
|
||||||
|
in our custom ``ChatLiteLLMRouter``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
|
||||||
|
agent_config: Optional ``AgentConfig`` driving provider-specific
|
||||||
|
behaviour. When omitted (or auto-mode), only the universal
|
||||||
|
``cache_control_injection_points`` are set.
|
||||||
|
thread_id: Optional thread id used to construct a per-thread
|
||||||
|
``prompt_cache_key`` for OpenAI-family providers. Caching still
|
||||||
|
works without it (server-side automatic), but the key improves
|
||||||
|
backend routing affinity and therefore hit rate.
|
||||||
|
"""
|
||||||
|
model_kwargs = _get_or_init_model_kwargs(llm)
|
||||||
|
if model_kwargs is None:
|
||||||
|
logger.debug(
|
||||||
|
"apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
|
||||||
|
type(llm).__name__,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if "cache_control_injection_points" not in model_kwargs:
|
||||||
|
model_kwargs["cache_control_injection_points"] = [
|
||||||
|
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||||
|
]
|
||||||
|
|
||||||
|
# OpenAI-family extras only when we statically know the destination is
|
||||||
|
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
|
||||||
|
# so we can't safely set OpenAI-only kwargs there (drop_params would
|
||||||
|
# strip them but it's wasteful to set them in the first place).
|
||||||
|
if _is_router_llm(llm):
|
||||||
|
return
|
||||||
|
if not _is_openai_family_config(agent_config):
|
||||||
|
return
|
||||||
|
|
||||||
|
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
|
||||||
|
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||||
|
if "prompt_cache_retention" not in model_kwargs:
|
||||||
|
model_kwargs["prompt_cache_retention"] = "24h"
|
||||||
|
|
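A minimal usage sketch of the helper above, not part of the diff itself: the way the chat model is constructed is an assumption (app-specific), only ``apply_litellm_prompt_caching`` and the resulting kwargs come from this module.

    # Hypothetical call site: wire caching onto an already-built chat model.
    llm = build_chat_model()  # placeholder for however the app creates its ChatLiteLLM
    apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=42)

    # For an explicit OpenAI-family config this leaves llm.model_kwargs as:
    # {
    #     "cache_control_injection_points": [
    #         {"location": "message", "index": 0},
    #         {"location": "message", "index": -1},
    #     ],
    #     "prompt_cache_key": "surfsense-thread-42",
    #     "prompt_cache_retention": "24h",
    # }
    # Auto-mode (ChatLiteLLMRouter) stops after the injection points.
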
@@ -27,14 +27,9 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
 NON_PROVIDER_STATE_MUTATION_DENY: frozenset[str] = frozenset(
     {
         # Exact tool names from shared deny patterns.
-        *{
-            name
-            for name in WRITE_TOOL_DENY_PATTERNS
-            if "*" not in name
-        },
+        *{name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name},
         # Additional non-provider state mutation controls.
         "write_todos",
         "task",
     }
 )

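To make the collapsed comprehension concrete, a small illustrative evaluation — the pattern values below are made up; only the filtering rule (keep names without ``*``) comes from the code:

    # Hypothetical deny patterns mixing exact tool names and wildcard patterns.
    WRITE_TOOL_DENY_PATTERNS = ("create_page", "delete_*", "update_page")
    exact_names = {name for name in WRITE_TOOL_DENY_PATTERNS if "*" not in name}
    # exact_names == {"create_page", "update_page"} — only exact names end up
    # in the frozenset; the wildcard entry is left out of this set.
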
@@ -112,10 +112,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
         Rule(permission=name, pattern="*", action="deny")
         for name in NON_PROVIDER_STATE_MUTATION_DENY
     )
-    rules.extend(
-        Rule(permission=name, pattern="*", action="ask")
-        for name in ask_tools
-    )
+    rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
     return PermissionMiddleware(
         rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")]
     )

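A condensed sketch of what the rule assembly amounts to after this change, assuming ``ask_tools`` holds the approval-gated tool names (``Rule``, ``Ruleset`` and ``PermissionMiddleware`` are the types already used above; the list initialisation is simplified relative to the real function):

    # Deny every non-provider state-mutation tool, then require approval for ask_tools.
    rules = [
        Rule(permission=name, pattern="*", action="deny")
        for name in NON_PROVIDER_STATE_MUTATION_DENY
    ]
    rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
    middleware = PermissionMiddleware(
        rulesets=[Ruleset(rules=rules, origin="subagent_linear_specialist")]
    )
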
@@ -163,4 +160,3 @@ def build_linear_specialist_subagent(
     if model is not None:
         spec["model"] = model
     return spec  # type: ignore[return-value]
-

@@ -119,10 +119,7 @@ def _permission_middleware(*, selected_tools: Sequence[BaseTool]) -> Any:
         Rule(permission=name, pattern="*", action="deny")
         for name in NON_PROVIDER_STATE_MUTATION_DENY
     )
-    rules.extend(
-        Rule(permission=name, pattern="*", action="ask")
-        for name in ask_tools
-    )
+    rules.extend(Rule(permission=name, pattern="*", action="ask") for name in ask_tools)
     return PermissionMiddleware(
         rulesets=[Ruleset(rules=rules, origin="subagent_slack_specialist")]
     )

@@ -171,4 +168,3 @@ def build_slack_specialist_subagent(
     if model is not None:
         spec["model"] = model
     return spec  # type: ignore[return-value]
-

@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
 
 from app.agents.new_chat.tools.hitl import request_approval
 from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
 from app.services.confluence import ConfluenceToolMetadataService
 
 logger = logging.getLogger(__name__)

@@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
     user_id: str | None = None,
     connector_id: int | None = None,
 ):
+    """
+    Factory function to create the create_confluence_page tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_confluence_page tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def create_confluence_page(
         title: str,

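The per-call-session idiom described in the docstring, reduced to a minimal hedged sketch: the tool name and body are illustrative; ``async_session_maker``, the ``del db_session`` convention, and the "not properly configured" guard are the parts taken from the diff.

    from langchain_core.tools import tool
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.db import async_session_maker

    def create_example_readonly_tool(
        db_session: AsyncSession | None = None,
        search_space_id: int | None = None,
        user_id: str | None = None,
    ):
        del db_session  # kept only for registry compatibility; never captured

        @tool
        async def example_readonly_tool() -> dict:
            """Illustrative tool body: open a fresh session for this call only."""
            if search_space_id is None or user_id is None:
                return {"status": "error", "message": "Tool not properly configured."}
            async with async_session_maker() as session:
                # ...run queries with `session`; it closes when the block exits, so the
                # cached closure never holds onto a stale request-scoped session.
                return {"status": "success"}

        return example_readonly_tool
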
@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
|
||||||
"""
|
"""
|
||||||
logger.info(f"create_confluence_page called: title='{title}'")
|
logger.info(f"create_confluence_page called: title='{title}'")
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
async with async_session_maker() as db_session:
|
||||||
context = await metadata_service.get_creation_context(
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
search_space_id, user_id
|
context = await metadata_service.get_creation_context(
|
||||||
)
|
search_space_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
return {"status": "error", "message": context["error"]}
|
return {"status": "error", "message": context["error"]}
|
||||||
|
|
||||||
accounts = context.get("accounts", [])
|
accounts = context.get("accounts", [])
|
||||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||||
return {
|
return {
|
||||||
"status": "auth_error",
|
"status": "auth_error",
|
||||||
"message": "All connected Confluence accounts need re-authentication.",
|
"message": "All connected Confluence accounts need re-authentication.",
|
||||||
"connector_type": "confluence",
|
"connector_type": "confluence",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = request_approval(
|
result = request_approval(
|
||||||
action_type="confluence_page_creation",
|
action_type="confluence_page_creation",
|
||||||
tool_name="create_confluence_page",
|
tool_name="create_confluence_page",
|
||||||
params={
|
params={
|
||||||
"title": title,
|
"title": title,
|
||||||
"content": content,
|
"content": content,
|
||||||
"space_id": space_id,
|
"space_id": space_id,
|
||||||
"connector_id": connector_id,
|
"connector_id": connector_id,
|
||||||
},
|
},
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.rejected:
|
if result.rejected:
|
||||||
return {
|
return {
|
||||||
"status": "rejected",
|
"status": "rejected",
|
||||||
"message": "User declined. Do not retry or suggest alternatives.",
|
"message": "User declined. Do not retry or suggest alternatives.",
|
||||||
}
|
}
|
||||||
|
|
||||||
final_title = result.params.get("title", title)
|
final_title = result.params.get("title", title)
|
||||||
final_content = result.params.get("content", content) or ""
|
final_content = result.params.get("content", content) or ""
|
||||||
final_space_id = result.params.get("space_id", space_id)
|
final_space_id = result.params.get("space_id", space_id)
|
||||||
final_connector_id = result.params.get("connector_id", connector_id)
|
final_connector_id = result.params.get("connector_id", connector_id)
|
||||||
|
|
||||||
if not final_title or not final_title.strip():
|
if not final_title or not final_title.strip():
|
||||||
return {"status": "error", "message": "Page title cannot be empty."}
|
return {"status": "error", "message": "Page title cannot be empty."}
|
||||||
if not final_space_id:
|
if not final_space_id:
|
||||||
return {"status": "error", "message": "A space must be selected."}
|
return {"status": "error", "message": "A space must be selected."}
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
actual_connector_id = final_connector_id
|
actual_connector_id = final_connector_id
|
||||||
if actual_connector_id is None:
|
if actual_connector_id is None:
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
SearchSourceConnector.user_id == user_id,
|
SearchSourceConnector.user_id == user_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.connector_type
|
||||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
connector = result.scalars().first()
|
||||||
connector = result.scalars().first()
|
if not connector:
|
||||||
if not connector:
|
return {
|
||||||
return {
|
"status": "error",
|
||||||
"status": "error",
|
"message": "No Confluence connector found.",
|
||||||
"message": "No Confluence connector found.",
|
}
|
||||||
}
|
actual_connector_id = connector.id
|
||||||
actual_connector_id = connector.id
|
|
||||||
else:
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SearchSourceConnector).filter(
|
|
||||||
SearchSourceConnector.id == actual_connector_id,
|
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
|
||||||
SearchSourceConnector.user_id == user_id,
|
|
||||||
SearchSourceConnector.connector_type
|
|
||||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connector = result.scalars().first()
|
|
||||||
if not connector:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Selected Confluence connector is invalid.",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = ConfluenceHistoryConnector(
|
|
||||||
session=db_session, connector_id=actual_connector_id
|
|
||||||
)
|
|
||||||
api_result = await client.create_page(
|
|
||||||
space_id=final_space_id,
|
|
||||||
title=final_title,
|
|
||||||
body=final_content,
|
|
||||||
)
|
|
||||||
await client.close()
|
|
||||||
except Exception as api_err:
|
|
||||||
if (
|
|
||||||
"http 403" in str(api_err).lower()
|
|
||||||
or "status code 403" in str(api_err).lower()
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
_conn = connector
|
|
||||||
_conn.config = {**_conn.config, "auth_expired": True}
|
|
||||||
flag_modified(_conn, "config")
|
|
||||||
await db_session.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {
|
|
||||||
"status": "insufficient_permissions",
|
|
||||||
"connector_id": actual_connector_id,
|
|
||||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
|
||||||
}
|
|
||||||
raise
|
|
||||||
|
|
||||||
page_id = str(api_result.get("id", ""))
|
|
||||||
page_links = (
|
|
||||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
|
||||||
)
|
|
||||||
page_url = ""
|
|
||||||
if page_links.get("base") and page_links.get("webui"):
|
|
||||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
|
||||||
|
|
||||||
kb_message_suffix = ""
|
|
||||||
try:
|
|
||||||
from app.services.confluence import ConfluenceKBSyncService
|
|
||||||
|
|
||||||
kb_service = ConfluenceKBSyncService(db_session)
|
|
||||||
kb_result = await kb_service.sync_after_create(
|
|
||||||
page_id=page_id,
|
|
||||||
page_title=final_title,
|
|
||||||
space_id=final_space_id,
|
|
||||||
body_content=final_content,
|
|
||||||
connector_id=actual_connector_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if kb_result["status"] == "success":
|
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
|
||||||
else:
|
else:
|
||||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
result = await db_session.execute(
|
||||||
except Exception as kb_err:
|
select(SearchSourceConnector).filter(
|
||||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
SearchSourceConnector.id == actual_connector_id,
|
||||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Selected Confluence connector is invalid.",
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
try:
|
||||||
"status": "success",
|
client = ConfluenceHistoryConnector(
|
||||||
"page_id": page_id,
|
session=db_session, connector_id=actual_connector_id
|
||||||
"page_url": page_url,
|
)
|
||||||
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
api_result = await client.create_page(
|
||||||
}
|
space_id=final_space_id,
|
||||||
|
title=final_title,
|
||||||
|
body=final_content,
|
||||||
|
)
|
||||||
|
await client.close()
|
||||||
|
except Exception as api_err:
|
||||||
|
if (
|
||||||
|
"http 403" in str(api_err).lower()
|
||||||
|
or "status code 403" in str(api_err).lower()
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
_conn = connector
|
||||||
|
_conn.config = {**_conn.config, "auth_expired": True}
|
||||||
|
flag_modified(_conn, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": actual_connector_id,
|
||||||
|
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
|
page_id = str(api_result.get("id", ""))
|
||||||
|
page_links = (
|
||||||
|
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||||
|
)
|
||||||
|
page_url = ""
|
||||||
|
if page_links.get("base") and page_links.get("webui"):
|
||||||
|
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||||
|
|
||||||
|
kb_message_suffix = ""
|
||||||
|
try:
|
||||||
|
from app.services.confluence import ConfluenceKBSyncService
|
||||||
|
|
||||||
|
kb_service = ConfluenceKBSyncService(db_session)
|
||||||
|
kb_result = await kb_service.sync_after_create(
|
||||||
|
page_id=page_id,
|
||||||
|
page_title=final_title,
|
||||||
|
space_id=final_space_id,
|
||||||
|
body_content=final_content,
|
||||||
|
connector_id=actual_connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if kb_result["status"] == "success":
|
||||||
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||||
|
except Exception as kb_err:
|
||||||
|
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||||
|
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"page_id": page_id,
|
||||||
|
"page_url": page_url,
|
||||||
|
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
 
 from app.agents.new_chat.tools.hitl import request_approval
 from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
 from app.services.confluence import ConfluenceToolMetadataService
 
 logger = logging.getLogger(__name__)

@@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
     user_id: str | None = None,
     connector_id: int | None = None,
 ):
+    """
+    Factory function to create the delete_confluence_page tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured delete_confluence_page tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def delete_confluence_page(
         page_title_or_id: str,

@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
|
||||||
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
async with async_session_maker() as db_session:
|
||||||
context = await metadata_service.get_deletion_context(
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
search_space_id, user_id, page_title_or_id
|
context = await metadata_service.get_deletion_context(
|
||||||
)
|
search_space_id, user_id, page_title_or_id
|
||||||
|
|
||||||
if "error" in context:
|
|
||||||
error_msg = context["error"]
|
|
||||||
if context.get("auth_expired"):
|
|
||||||
return {
|
|
||||||
"status": "auth_error",
|
|
||||||
"message": error_msg,
|
|
||||||
"connector_id": context.get("connector_id"),
|
|
||||||
"connector_type": "confluence",
|
|
||||||
}
|
|
||||||
if "not found" in error_msg.lower():
|
|
||||||
return {"status": "not_found", "message": error_msg}
|
|
||||||
return {"status": "error", "message": error_msg}
|
|
||||||
|
|
||||||
page_data = context["page"]
|
|
||||||
page_id = page_data["page_id"]
|
|
||||||
page_title = page_data.get("page_title", "")
|
|
||||||
document_id = page_data["document_id"]
|
|
||||||
connector_id_from_context = context.get("account", {}).get("id")
|
|
||||||
|
|
||||||
result = request_approval(
|
|
||||||
action_type="confluence_page_deletion",
|
|
||||||
tool_name="delete_confluence_page",
|
|
||||||
params={
|
|
||||||
"page_id": page_id,
|
|
||||||
"connector_id": connector_id_from_context,
|
|
||||||
"delete_from_kb": delete_from_kb,
|
|
||||||
},
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.rejected:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"message": "User declined. Do not retry or suggest alternatives.",
|
|
||||||
}
|
|
||||||
|
|
||||||
final_page_id = result.params.get("page_id", page_id)
|
|
||||||
final_connector_id = result.params.get(
|
|
||||||
"connector_id", connector_id_from_context
|
|
||||||
)
|
|
||||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|
||||||
|
|
||||||
if not final_connector_id:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "No connector found for this page.",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SearchSourceConnector).filter(
|
|
||||||
SearchSourceConnector.id == final_connector_id,
|
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
|
||||||
SearchSourceConnector.user_id == user_id,
|
|
||||||
SearchSourceConnector.connector_type
|
|
||||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
connector = result.scalars().first()
|
|
||||||
if not connector:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Selected Confluence connector is invalid.",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
if "error" in context:
|
||||||
client = ConfluenceHistoryConnector(
|
error_msg = context["error"]
|
||||||
session=db_session, connector_id=final_connector_id
|
if context.get("auth_expired"):
|
||||||
|
return {
|
||||||
|
"status": "auth_error",
|
||||||
|
"message": error_msg,
|
||||||
|
"connector_id": context.get("connector_id"),
|
||||||
|
"connector_type": "confluence",
|
||||||
|
}
|
||||||
|
if "not found" in error_msg.lower():
|
||||||
|
return {"status": "not_found", "message": error_msg}
|
||||||
|
return {"status": "error", "message": error_msg}
|
||||||
|
|
||||||
|
page_data = context["page"]
|
||||||
|
page_id = page_data["page_id"]
|
||||||
|
page_title = page_data.get("page_title", "")
|
||||||
|
document_id = page_data["document_id"]
|
||||||
|
connector_id_from_context = context.get("account", {}).get("id")
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="confluence_page_deletion",
|
||||||
|
tool_name="delete_confluence_page",
|
||||||
|
params={
|
||||||
|
"page_id": page_id,
|
||||||
|
"connector_id": connector_id_from_context,
|
||||||
|
"delete_from_kb": delete_from_kb,
|
||||||
|
},
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
await client.delete_page(final_page_id)
|
|
||||||
await client.close()
|
if result.rejected:
|
||||||
except Exception as api_err:
|
|
||||||
if (
|
|
||||||
"http 403" in str(api_err).lower()
|
|
||||||
or "status code 403" in str(api_err).lower()
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
|
||||||
flag_modified(connector, "config")
|
|
||||||
await db_session.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {
|
return {
|
||||||
"status": "insufficient_permissions",
|
"status": "rejected",
|
||||||
"connector_id": final_connector_id,
|
"message": "User declined. Do not retry or suggest alternatives.",
|
||||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
|
||||||
}
|
}
|
||||||
raise
|
|
||||||
|
|
||||||
deleted_from_kb = False
|
final_page_id = result.params.get("page_id", page_id)
|
||||||
if final_delete_from_kb and document_id:
|
final_connector_id = result.params.get(
|
||||||
try:
|
"connector_id", connector_id_from_context
|
||||||
from app.db import Document
|
)
|
||||||
|
final_delete_from_kb = result.params.get(
|
||||||
|
"delete_from_kb", delete_from_kb
|
||||||
|
)
|
||||||
|
|
||||||
doc_result = await db_session.execute(
|
from sqlalchemy.future import select
|
||||||
select(Document).filter(Document.id == document_id)
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
if not final_connector_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No connector found for this page.",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == final_connector_id,
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||||
)
|
)
|
||||||
document = doc_result.scalars().first()
|
)
|
||||||
if document:
|
connector = result.scalars().first()
|
||||||
await db_session.delete(document)
|
if not connector:
|
||||||
await db_session.commit()
|
return {
|
||||||
deleted_from_kb = True
|
"status": "error",
|
||||||
except Exception as e:
|
"message": "Selected Confluence connector is invalid.",
|
||||||
logger.error(f"Failed to delete document from KB: {e}")
|
}
|
||||||
await db_session.rollback()
|
|
||||||
|
|
||||||
message = f"Confluence page '{page_title}' deleted successfully."
|
try:
|
||||||
if deleted_from_kb:
|
client = ConfluenceHistoryConnector(
|
||||||
message += " Also removed from the knowledge base."
|
session=db_session, connector_id=final_connector_id
|
||||||
|
)
|
||||||
|
await client.delete_page(final_page_id)
|
||||||
|
await client.close()
|
||||||
|
except Exception as api_err:
|
||||||
|
if (
|
||||||
|
"http 403" in str(api_err).lower()
|
||||||
|
or "status code 403" in str(api_err).lower()
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": final_connector_id,
|
||||||
|
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
return {
|
deleted_from_kb = False
|
||||||
"status": "success",
|
if final_delete_from_kb and document_id:
|
||||||
"page_id": final_page_id,
|
try:
|
||||||
"deleted_from_kb": deleted_from_kb,
|
from app.db import Document
|
||||||
"message": message,
|
|
||||||
}
|
doc_result = await db_session.execute(
|
||||||
|
select(Document).filter(Document.id == document_id)
|
||||||
|
)
|
||||||
|
document = doc_result.scalars().first()
|
||||||
|
if document:
|
||||||
|
await db_session.delete(document)
|
||||||
|
await db_session.commit()
|
||||||
|
deleted_from_kb = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete document from KB: {e}")
|
||||||
|
await db_session.rollback()
|
||||||
|
|
||||||
|
message = f"Confluence page '{page_title}' deleted successfully."
|
||||||
|
if deleted_from_kb:
|
||||||
|
message += " Also removed from the knowledge base."
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"page_id": final_page_id,
|
||||||
|
"deleted_from_kb": deleted_from_kb,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
 
 from app.agents.new_chat.tools.hitl import request_approval
 from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.db import async_session_maker
 from app.services.confluence import ConfluenceToolMetadataService
 
 logger = logging.getLogger(__name__)

@@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
     user_id: str | None = None,
     connector_id: int | None = None,
 ):
+    """
+    Factory function to create the update_confluence_page tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured update_confluence_page tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def update_confluence_page(
         page_title_or_id: str,

@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
|
||||||
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Confluence tool not properly configured.",
|
"message": "Confluence tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
async with async_session_maker() as db_session:
|
||||||
context = await metadata_service.get_update_context(
|
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||||
search_space_id, user_id, page_title_or_id
|
context = await metadata_service.get_update_context(
|
||||||
)
|
search_space_id, user_id, page_title_or_id
|
||||||
|
)
|
||||||
|
|
||||||
if "error" in context:
|
if "error" in context:
|
||||||
error_msg = context["error"]
|
error_msg = context["error"]
|
||||||
if context.get("auth_expired"):
|
if context.get("auth_expired"):
|
||||||
|
return {
|
||||||
|
"status": "auth_error",
|
||||||
|
"message": error_msg,
|
||||||
|
"connector_id": context.get("connector_id"),
|
||||||
|
"connector_type": "confluence",
|
||||||
|
}
|
||||||
|
if "not found" in error_msg.lower():
|
||||||
|
return {"status": "not_found", "message": error_msg}
|
||||||
|
return {"status": "error", "message": error_msg}
|
||||||
|
|
||||||
|
page_data = context["page"]
|
||||||
|
page_id = page_data["page_id"]
|
||||||
|
current_title = page_data["page_title"]
|
||||||
|
current_body = page_data.get("body", "")
|
||||||
|
current_version = page_data.get("version", 1)
|
||||||
|
document_id = page_data.get("document_id")
|
||||||
|
connector_id_from_context = context.get("account", {}).get("id")
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="confluence_page_update",
|
||||||
|
tool_name="update_confluence_page",
|
||||||
|
params={
|
||||||
|
"page_id": page_id,
|
||||||
|
"document_id": document_id,
|
||||||
|
"new_title": new_title,
|
||||||
|
"new_content": new_content,
|
||||||
|
"version": current_version,
|
||||||
|
"connector_id": connector_id_from_context,
|
||||||
|
},
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rejected:
|
||||||
return {
|
return {
|
||||||
"status": "auth_error",
|
"status": "rejected",
|
||||||
"message": error_msg,
|
"message": "User declined. Do not retry or suggest alternatives.",
|
||||||
"connector_id": context.get("connector_id"),
|
|
||||||
"connector_type": "confluence",
|
|
||||||
}
|
}
|
||||||
if "not found" in error_msg.lower():
|
|
||||||
return {"status": "not_found", "message": error_msg}
|
|
||||||
return {"status": "error", "message": error_msg}
|
|
||||||
|
|
||||||
page_data = context["page"]
|
final_page_id = result.params.get("page_id", page_id)
|
||||||
page_id = page_data["page_id"]
|
final_title = result.params.get("new_title", new_title) or current_title
|
||||||
current_title = page_data["page_title"]
|
final_content = result.params.get("new_content", new_content)
|
||||||
current_body = page_data.get("body", "")
|
if final_content is None:
|
||||||
current_version = page_data.get("version", 1)
|
final_content = current_body
|
||||||
document_id = page_data.get("document_id")
|
final_version = result.params.get("version", current_version)
|
||||||
connector_id_from_context = context.get("account", {}).get("id")
|
final_connector_id = result.params.get(
|
||||||
|
"connector_id", connector_id_from_context
|
||||||
result = request_approval(
|
|
||||||
action_type="confluence_page_update",
|
|
||||||
tool_name="update_confluence_page",
|
|
||||||
params={
|
|
||||||
"page_id": page_id,
|
|
||||||
"document_id": document_id,
|
|
||||||
"new_title": new_title,
|
|
||||||
"new_content": new_content,
|
|
||||||
"version": current_version,
|
|
||||||
"connector_id": connector_id_from_context,
|
|
||||||
},
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.rejected:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"message": "User declined. Do not retry or suggest alternatives.",
|
|
||||||
}
|
|
||||||
|
|
||||||
final_page_id = result.params.get("page_id", page_id)
|
|
||||||
final_title = result.params.get("new_title", new_title) or current_title
|
|
||||||
final_content = result.params.get("new_content", new_content)
|
|
||||||
if final_content is None:
|
|
||||||
final_content = current_body
|
|
||||||
final_version = result.params.get("version", current_version)
|
|
||||||
final_connector_id = result.params.get(
|
|
||||||
"connector_id", connector_id_from_context
|
|
||||||
)
|
|
||||||
final_document_id = result.params.get("document_id", document_id)
|
|
||||||
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|
||||||
|
|
||||||
if not final_connector_id:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "No connector found for this page.",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SearchSourceConnector).filter(
|
|
||||||
SearchSourceConnector.id == final_connector_id,
|
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
|
||||||
SearchSourceConnector.user_id == user_id,
|
|
||||||
SearchSourceConnector.connector_type
|
|
||||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
|
||||||
)
|
)
|
||||||
)
|
final_document_id = result.params.get("document_id", document_id)
|
||||||
connector = result.scalars().first()
|
|
||||||
if not connector:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Selected Confluence connector is invalid.",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
from sqlalchemy.future import select
|
||||||
client = ConfluenceHistoryConnector(
|
|
||||||
session=db_session, connector_id=final_connector_id
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
)
|
|
||||||
api_result = await client.update_page(
|
if not final_connector_id:
|
||||||
page_id=final_page_id,
|
|
||||||
title=final_title,
|
|
||||||
body=final_content,
|
|
||||||
version_number=final_version + 1,
|
|
||||||
)
|
|
||||||
await client.close()
|
|
||||||
except Exception as api_err:
|
|
||||||
if (
|
|
||||||
"http 403" in str(api_err).lower()
|
|
||||||
or "status code 403" in str(api_err).lower()
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
connector.config = {**connector.config, "auth_expired": True}
|
|
||||||
flag_modified(connector, "config")
|
|
||||||
await db_session.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {
|
return {
|
||||||
"status": "insufficient_permissions",
|
"status": "error",
|
||||||
"connector_id": final_connector_id,
|
"message": "No connector found for this page.",
|
||||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
|
||||||
}
|
}
|
||||||
raise
|
|
||||||
|
|
||||||
page_links = (
|
result = await db_session.execute(
|
||||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
select(SearchSourceConnector).filter(
|
||||||
)
|
SearchSourceConnector.id == final_connector_id,
|
||||||
page_url = ""
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
if page_links.get("base") and page_links.get("webui"):
|
SearchSourceConnector.user_id == user_id,
|
||||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||||
kb_message_suffix = ""
|
|
||||||
if final_document_id:
|
|
||||||
try:
|
|
||||||
from app.services.confluence import ConfluenceKBSyncService
|
|
||||||
|
|
||||||
kb_service = ConfluenceKBSyncService(db_session)
|
|
||||||
kb_result = await kb_service.sync_after_update(
|
|
||||||
document_id=final_document_id,
|
|
||||||
page_id=final_page_id,
|
|
||||||
user_id=user_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
)
|
||||||
kb_message_suffix = (
|
connector = result.scalars().first()
|
||||||
" Your knowledge base has also been updated."
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Selected Confluence connector is invalid.",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = ConfluenceHistoryConnector(
|
||||||
|
session=db_session, connector_id=final_connector_id
|
||||||
|
)
|
||||||
|
api_result = await client.update_page(
|
||||||
|
page_id=final_page_id,
|
||||||
|
title=final_title,
|
||||||
|
body=final_content,
|
||||||
|
version_number=final_version + 1,
|
||||||
|
)
|
||||||
|
await client.close()
|
||||||
|
except Exception as api_err:
|
||||||
|
if (
|
||||||
|
"http 403" in str(api_err).lower()
|
||||||
|
or "status code 403" in str(api_err).lower()
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": final_connector_id,
|
||||||
|
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
|
page_links = (
|
||||||
|
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||||
|
)
|
||||||
|
page_url = ""
|
||||||
|
if page_links.get("base") and page_links.get("webui"):
|
||||||
|
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||||
|
|
||||||
|
kb_message_suffix = ""
|
||||||
|
if final_document_id:
|
||||||
|
try:
|
||||||
|
from app.services.confluence import ConfluenceKBSyncService
|
||||||
|
|
||||||
|
kb_service = ConfluenceKBSyncService(db_session)
|
||||||
|
kb_result = await kb_service.sync_after_update(
|
||||||
|
document_id=final_document_id,
|
||||||
|
page_id=final_page_id,
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
)
|
)
|
||||||
else:
|
if kb_result["status"] == "success":
|
||||||
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kb_message_suffix = (
|
||||||
|
" The knowledge base will be updated in the next sync."
|
||||||
|
)
|
||||||
|
except Exception as kb_err:
|
||||||
|
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||||
kb_message_suffix = (
|
kb_message_suffix = (
|
||||||
" The knowledge base will be updated in the next sync."
|
" The knowledge base will be updated in the next sync."
|
||||||
)
|
)
|
||||||
except Exception as kb_err:
|
|
||||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
|
||||||
kb_message_suffix = (
|
|
||||||
" The knowledge base will be updated in the next sync."
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"page_id": final_page_id,
|
"page_id": final_page_id,
|
||||||
"page_url": page_url,
|
"page_url": page_url,
|
||||||
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
 from app.services.mcp_oauth.registry import MCP_SERVICES
 
 logger = logging.getLogger(__name__)

@@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
     search_space_id: int,
     user_id: str,
 ) -> StructuredTool:
+    """Factory function to create the get_connected_accounts tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+        search_space_id: Search space ID to scope account discovery to.
+        user_id: User ID to scope account discovery to.
+
+    Returns:
+        Configured StructuredTool for connected-accounts discovery.
+    """
+    del db_session  # per-call session — see docstring
 
     async def _run(service: str) -> list[dict[str, Any]]:
         svc_cfg = MCP_SERVICES.get(service)

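For orientation, a hedged sketch (inside a factory like the one above) of how the inner coroutine typically opens its own session and is exposed as a tool — the query shape, description text, and ``from_function`` wiring are illustrative, not read from this diff:

    from typing import Any

    from langchain_core.tools import StructuredTool
    from sqlalchemy.future import select

    from app.db import SearchSourceConnector, async_session_maker

    # ...inside create_get_connected_accounts_tool(db_session, search_space_id, user_id):
    async def _run(service: str) -> list[dict[str, Any]]:
        # Per-call session, mirroring the docstring above.
        async with async_session_maker() as session:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                )
            )
            return [{"connector_id": conn.id} for conn in result.scalars().all()]

    connected_accounts_tool = StructuredTool.from_function(
        coroutine=_run,
        name="get_connected_accounts",
        description="List connected accounts for a given service.",  # illustrative wording
    )
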
@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||||
|
|
||||||
result = await db_session.execute(
|
async with async_session_maker() as db_session:
|
||||||
select(SearchSourceConnector).filter(
|
result = await db_session.execute(
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.user_id == user_id,
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
SearchSourceConnector.connector_type == connector_type,
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == connector_type,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
connectors = result.scalars().all()
|
||||||
connectors = result.scalars().all()
|
|
||||||
|
|
||||||
if not connectors:
|
if not connectors:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
|
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
is_multi = len(connectors) > 1
|
||||||
|
|
||||||
|
accounts: list[dict[str, Any]] = []
|
||||||
|
for conn in connectors:
|
||||||
|
cfg = conn.config or {}
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"connector_id": conn.id,
|
||||||
|
"display_name": _extract_display_name(conn),
|
||||||
|
"service": service,
|
||||||
}
|
}
|
||||||
]
|
if is_multi:
|
||||||
|
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||||
|
for key in svc_cfg.account_metadata_keys:
|
||||||
|
if key in cfg:
|
||||||
|
entry[key] = cfg[key]
|
||||||
|
accounts.append(entry)
|
||||||
|
|
||||||
is_multi = len(connectors) > 1
|
return accounts
|
||||||
|
|
||||||
accounts: list[dict[str, Any]] = []
|
|
||||||
for conn in connectors:
|
|
||||||
cfg = conn.config or {}
|
|
||||||
entry: dict[str, Any] = {
|
|
||||||
"connector_id": conn.id,
|
|
||||||
"display_name": _extract_display_name(conn),
|
|
||||||
"service": service,
|
|
||||||
}
|
|
||||||
if is_multi:
|
|
||||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
|
||||||
for key in svc_cfg.account_metadata_keys:
|
|
||||||
if key in cfg:
|
|
||||||
entry[key] = cfg[key]
|
|
||||||
accounts.append(entry)
|
|
||||||
|
|
||||||
return accounts
|
|
||||||
|
|
||||||
return StructuredTool(
|
return StructuredTool(
|
||||||
name="get_connected_accounts",
|
name="get_connected_accounts",
|
||||||
|
|
|
||||||
|
|
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.db import async_session_maker
+
 from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
 
 logger = logging.getLogger(__name__)

@@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the list_discord_channels tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured list_discord_channels tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def list_discord_channels() -> dict[str, Any]:
         """List text channels in the connected Discord server.

@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with status and a list of channels (id, name).
|
Dictionary with status and a list of channels (id, name).
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_discord_connector(
|
async with async_session_maker() as db_session:
|
||||||
db_session, search_space_id, user_id
|
connector = await get_discord_connector(
|
||||||
)
|
db_session, search_space_id, user_id
|
||||||
if not connector:
|
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
|
||||||
|
|
||||||
guild_id = get_guild_id(connector)
|
|
||||||
if not guild_id:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "No guild ID in Discord connector config.",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(
|
|
||||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
|
||||||
headers={"Authorization": f"Bot {token}"},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
if resp.status_code == 401:
|
guild_id = get_guild_id(connector)
|
||||||
return {
|
if not guild_id:
|
||||||
"status": "auth_error",
|
return {
|
||||||
"message": "Discord bot token is invalid.",
|
"status": "error",
|
||||||
"connector_type": "discord",
|
"message": "No guild ID in Discord connector config.",
|
||||||
}
|
}
|
||||||
if resp.status_code != 200:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Type 0 = text channel
|
token = get_bot_token(connector)
|
||||||
channels = [
|
|
||||||
{"id": ch["id"], "name": ch["name"]}
|
async with httpx.AsyncClient() as client:
|
||||||
for ch in resp.json()
|
resp = await client.get(
|
||||||
if ch.get("type") == 0
|
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||||
]
|
headers={"Authorization": f"Bot {token}"},
|
||||||
return {
|
timeout=15.0,
|
||||||
"status": "success",
|
)
|
||||||
"guild_id": guild_id,
|
|
||||||
"channels": channels,
|
if resp.status_code == 401:
|
||||||
"total": len(channels),
|
return {
|
||||||
}
|
"status": "auth_error",
|
||||||
|
"message": "Discord bot token is invalid.",
|
||||||
|
"connector_type": "discord",
|
||||||
|
}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Discord API error: {resp.status_code}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Type 0 = text channel
|
||||||
|
channels = [
|
||||||
|
{"id": ch["id"], "name": ch["name"]}
|
||||||
|
for ch in resp.json()
|
||||||
|
if ch.get("type") == 0
|
||||||
|
]
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"channels": channels,
|
||||||
|
"total": len(channels),
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.db import async_session_maker
+
 from ._auth import DISCORD_API, get_bot_token, get_discord_connector
 
 logger = logging.getLogger(__name__)

@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the read_discord_messages tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured read_discord_messages tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def read_discord_messages(
|
async def read_discord_messages(
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
|
|
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
|
||||||
Dictionary with status and a list of messages including
|
Dictionary with status and a list of messages including
|
||||||
id, author, content, timestamp.
|
id, author, content, timestamp.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
|
|
@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
|
||||||
limit = min(limit, 50)
|
limit = min(limit, 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_discord_connector(
|
async with async_session_maker() as db_session:
|
||||||
db_session, search_space_id, user_id
|
connector = await get_discord_connector(
|
||||||
)
|
db_session, search_space_id, user_id
|
||||||
if not connector:
|
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(
|
|
||||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
|
||||||
headers={"Authorization": f"Bot {token}"},
|
|
||||||
params={"limit": limit},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
if resp.status_code == 401:
|
token = get_bot_token(connector)
|
||||||
return {
|
|
||||||
"status": "auth_error",
|
|
||||||
"message": "Discord bot token is invalid.",
|
|
||||||
"connector_type": "discord",
|
|
||||||
}
|
|
||||||
if resp.status_code == 403:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Bot lacks permission to read this channel.",
|
|
||||||
}
|
|
||||||
if resp.status_code != 200:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = [
|
async with httpx.AsyncClient() as client:
|
||||||
{
|
resp = await client.get(
|
||||||
"id": m["id"],
|
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||||
"author": m.get("author", {}).get("username", "Unknown"),
|
headers={"Authorization": f"Bot {token}"},
|
||||||
"content": m.get("content", ""),
|
params={"limit": limit},
|
||||||
"timestamp": m.get("timestamp", ""),
|
timeout=15.0,
|
||||||
}
|
)
|
||||||
for m in resp.json()
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
if resp.status_code == 401:
|
||||||
"status": "success",
|
return {
|
||||||
"channel_id": channel_id,
|
"status": "auth_error",
|
||||||
"messages": messages,
|
"message": "Discord bot token is invalid.",
|
||||||
"total": len(messages),
|
"connector_type": "discord",
|
||||||
}
|
}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Bot lacks permission to read this channel.",
|
||||||
|
}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Discord API error: {resp.status_code}",
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"id": m["id"],
|
||||||
|
"author": m.get("author", {}).get("username", "Unknown"),
|
||||||
|
"content": m.get("content", ""),
|
||||||
|
"timestamp": m.get("timestamp", ""),
|
||||||
|
}
|
||||||
|
for m in resp.json()
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"messages": messages,
|
||||||
|
"total": len(messages),
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
|
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Factory function to create the send_discord_message tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker` so the closure is safe to share across
|
||||||
|
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||||
|
session here would surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured send_discord_message tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def send_discord_message(
|
async def send_discord_message(
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
|
|
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
|
||||||
IMPORTANT:
|
IMPORTANT:
|
||||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
"""
|
"""
|
||||||
if db_session is None or search_space_id is None or user_id is None:
|
if search_space_id is None or user_id is None:
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "Discord tool not properly configured.",
|
"message": "Discord tool not properly configured.",
|
||||||
|
|
@ -47,64 +65,65 @@ def create_send_discord_message_tool(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connector = await get_discord_connector(
|
async with async_session_maker() as db_session:
|
||||||
db_session, search_space_id, user_id
|
connector = await get_discord_connector(
|
||||||
)
|
db_session, search_space_id, user_id
|
||||||
if not connector:
|
)
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
result = request_approval(
|
result = request_approval(
|
||||||
action_type="discord_send_message",
|
action_type="discord_send_message",
|
||||||
tool_name="send_discord_message",
|
tool_name="send_discord_message",
|
||||||
params={"channel_id": channel_id, "content": content},
|
params={"channel_id": channel_id, "content": content},
|
||||||
context={"connector_id": connector.id},
|
context={"connector_id": connector.id},
|
||||||
)
|
|
||||||
|
|
||||||
if result.rejected:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"message": "User declined. Message was not sent.",
|
|
||||||
}
|
|
||||||
|
|
||||||
final_content = result.params.get("content", content)
|
|
||||||
final_channel = result.params.get("channel_id", channel_id)
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.post(
|
|
||||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bot {token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
json={"content": final_content},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp.status_code == 401:
|
if result.rejected:
|
||||||
return {
|
return {
|
||||||
"status": "auth_error",
|
"status": "rejected",
|
||||||
"message": "Discord bot token is invalid.",
|
"message": "User declined. Message was not sent.",
|
||||||
"connector_type": "discord",
|
}
|
||||||
}
|
|
||||||
if resp.status_code == 403:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Bot lacks permission to send messages in this channel.",
|
|
||||||
}
|
|
||||||
if resp.status_code not in (200, 201):
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
msg_data = resp.json()
|
final_content = result.params.get("content", content)
|
||||||
return {
|
final_channel = result.params.get("channel_id", channel_id)
|
||||||
"status": "success",
|
|
||||||
"message_id": msg_data.get("id"),
|
token = get_bot_token(connector)
|
||||||
"message": f"Message sent to channel {final_channel}.",
|
|
||||||
}
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={"content": final_content},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {
|
||||||
|
"status": "auth_error",
|
||||||
|
"message": "Discord bot token is invalid.",
|
||||||
|
"connector_type": "discord",
|
||||||
|
}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Bot lacks permission to send messages in this channel.",
|
||||||
|
}
|
||||||
|
if resp.status_code not in (200, 201):
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Discord API error: {resp.status_code}",
|
||||||
|
}
|
||||||
|
|
||||||
|
msg_data = resp.json()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": msg_data.get("id"),
|
||||||
|
"message": f"Message sent to channel {final_channel}.",
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
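All three Discord hunks above make the same structural change: the factory no longer closes over a request-scoped session, but opens a short-lived one per invocation. A minimal sketch of that pattern follows, assuming only the `async_session_maker` import shown in the diff; the factory and tool names here are illustrative, not part of the commit.

# Illustrative sketch, not code from this commit.
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import async_session_maker  # import added by the hunks above


def create_example_tool(
    db_session: AsyncSession | None = None,  # kept only for registry compatibility
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    # Never capture the injected session: a cached closure must not hold a
    # request-scoped session across HTTP requests.
    del db_session

    @tool
    async def example_tool(query: str) -> dict:
        """Hypothetical tool body showing the session-per-call pattern."""
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        async with async_session_maker() as db_session:
            # Fetch connectors or documents with db_session here; the session
            # is released as soon as the block exits.
            return {"status": "success", "query": query}

    return example_tool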
@@ -10,7 +10,7 @@ from sqlalchemy.future import select

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker

logger = logging.getLogger(__name__)

@@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
+    """
+    Factory function to create the create_dropbox_file tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_dropbox_file tool
+    """
+    del db_session  # per-call session — see docstring
+
    @tool
    async def create_dropbox_file(
        name: str,

@@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
            f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
        )

-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
+            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                    )
                )
                connectors = result.scalars().all()

                if not connectors:
                    return {
                        "status": "error",
                        "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
                    }

                accounts = []
                for c in connectors:
                    cfg = c.config or {}
                    accounts.append(
                        {
                            "id": c.id,
                            "name": c.name,
                            "user_email": cfg.get("user_email"),
                            "auth_expired": cfg.get("auth_expired", False),
                        }
                    )

                if all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected Dropbox accounts need re-authentication.",
                        "connector_type": "dropbox",
                    }

                parent_folders: dict[int, list[dict[str, str]]] = {}
                for acc in accounts:
                    cid = acc["id"]
                    if acc.get("auth_expired"):
                        parent_folders[cid] = []
                        continue
                    try:
                        client = DropboxClient(session=db_session, connector_id=cid)
                        items, err = await client.list_folder("")
                        if err:
                            logger.warning(
                                "Failed to list folders for connector %s: %s", cid, err
                            )
                            parent_folders[cid] = []
                        else:
                            parent_folders[cid] = [
                                {
                                    "folder_path": item.get("path_lower", ""),
                                    "name": item["name"],
                                }
                                for item in items
                                if item.get(".tag") == "folder" and item.get("name")
                            ]
                    except Exception:
-                        logger.warning(
-                            "Error fetching folders for connector %s", cid, exc_info=True
-                        )
+                        logger.warning(
+                            "Error fetching folders for connector %s",
+                            cid,
+                            exc_info=True,
+                        )
                        parent_folders[cid] = []

                context: dict[str, Any] = {
                    "accounts": accounts,
                    "parent_folders": parent_folders,
                    "supported_types": _SUPPORTED_TYPES,
                }

                result = request_approval(
                    action_type="dropbox_file_creation",
                    tool_name="create_dropbox_file",
                    params={
                        "name": name,
                        "file_type": file_type,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_path": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_name = result.params.get("name", name)
                final_file_type = result.params.get("file_type", file_type)
                final_content = result.params.get("content", content)
                final_connector_id = result.params.get("connector_id")
                final_parent_folder_path = result.params.get("parent_folder_path")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                final_name = _ensure_extension(final_name, final_file_type)

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                else:
                    connector = connectors[0]

                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Dropbox connector is invalid.",
                    }

                client = DropboxClient(session=db_session, connector_id=connector.id)

                parent_path = final_parent_folder_path or ""
                file_path = (
                    f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
                )

                if final_file_type == "paper":
-                    created = await client.create_paper_doc(file_path, final_content or "")
+                    created = await client.create_paper_doc(
+                        file_path, final_content or ""
+                    )
                    file_id = created.get("file_id", "")
                    web_url = created.get("url", "")
                else:
                    docx_bytes = _markdown_to_docx(final_content or "")
                    created = await client.upload_file(
                        file_path, docx_bytes, mode="add", autorename=True
                    )
                    file_id = created.get("id", "")
                    web_url = ""

                logger.info(f"Dropbox file created: id={file_id}, name={final_name}")

                kb_message_suffix = ""
                try:
                    from app.services.dropbox import DropboxKBSyncService

                    kb_service = DropboxKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=file_id,
                        file_name=final_name,
                        file_path=file_path,
                        web_url=web_url,
                        content=final_content,
                        connector_id=connector.id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
-                        kb_message_suffix = " Your knowledge base has also been updated."
+                        kb_message_suffix = (
+                            " Your knowledge base has also been updated."
+                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": file_id,
                    "name": final_name,
                    "web_url": web_url,
                    "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
@@ -13,6 +13,7 @@ from app.db import (
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
+    async_session_maker,
)

logger = logging.getLogger(__name__)

@@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
+    """
+    Factory function to create the delete_dropbox_file tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured delete_dropbox_file tool
+    """
+    del db_session  # per-call session — see docstring
+
    @tool
    async def delete_dropbox_file(
        file_name: str,

@@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
            f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
+            async with async_session_maker() as db_session:
                doc_result = await db_session.execute(
                    select(Document)
                    .join(
                        SearchSourceConnector,
                        Document.connector_id == SearchSourceConnector.id,
                    )
                    .filter(
                        and_(
                            Document.search_space_id == search_space_id,
                            Document.document_type == DocumentType.DROPBOX_FILE,
                            func.lower(Document.title) == func.lower(file_name),
                            SearchSourceConnector.user_id == user_id,
                        )
                    )
                    .order_by(Document.updated_at.desc().nullslast())
                    .limit(1)
                )
                document = doc_result.scalars().first()

                if not document:
                    doc_result = await db_session.execute(
                        select(Document)
                        .join(

@@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
                            and_(
                                Document.search_space_id == search_space_id,
                                Document.document_type == DocumentType.DROPBOX_FILE,
-                                func.lower(
-                                    cast(
-                                        Document.document_metadata["dropbox_file_name"],
-                                        String,
-                                    )
-                                )
-                                == func.lower(file_name),
+                                func.lower(Document.title) == func.lower(file_name),
                                SearchSourceConnector.user_id == user_id,
                            )
                        )

@@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
                    )
                    document = doc_result.scalars().first()

                if not document:
+                    doc_result = await db_session.execute(
+                        select(Document)
+                        .join(
+                            SearchSourceConnector,
+                            Document.connector_id == SearchSourceConnector.id,
+                        )
+                        .filter(
+                            and_(
+                                Document.search_space_id == search_space_id,
+                                Document.document_type == DocumentType.DROPBOX_FILE,
+                                func.lower(
+                                    cast(
+                                        Document.document_metadata["dropbox_file_name"],
+                                        String,
+                                    )
+                                )
+                                == func.lower(file_name),
+                                SearchSourceConnector.user_id == user_id,
+                            )
+                        )
+                        .order_by(Document.updated_at.desc().nullslast())
+                        .limit(1)
+                    )
+                    document = doc_result.scalars().first()
+
+                if not document:
                    return {
                        "status": "not_found",
                        "message": (
                            f"File '{file_name}' not found in your indexed Dropbox files. "
                            "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
                            "or (3) the file name is different."
                        ),
                    }

                if not document.connector_id:
                    return {
                        "status": "error",
                        "message": "Document has no associated connector.",
                    }

                meta = document.document_metadata or {}
                file_path = meta.get("dropbox_path")
                file_id = meta.get("dropbox_file_id")
                document_id = document.id

                if not file_path:
                    return {
                        "status": "error",
                        "message": "File path is missing. Please re-index the file.",
                    }

                conn_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        and_(
                            SearchSourceConnector.id == document.connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type

@@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
                            == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                        )
                    )
                )
                connector = conn_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Dropbox connector not found or access denied.",
                    }

                cfg = connector.config or {}
                if cfg.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "dropbox",
                    }

                context = {
                    "file": {
                        "file_id": file_id,
                        "file_path": file_path,
                        "name": file_name,
                        "document_id": document_id,
                    },
                    "account": {
                        "id": connector.id,
                        "name": connector.name,
                        "user_email": cfg.get("user_email"),
                    },
                }

                result = request_approval(
                    action_type="dropbox_file_trash",
                    tool_name="delete_dropbox_file",
                    params={
                        "file_path": file_path,
                        "connector_id": connector.id,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_file_path = result.params.get("file_path", file_path)
                final_connector_id = result.params.get("connector_id", connector.id)
-                final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
+                final_delete_from_kb = result.params.get(
+                    "delete_from_kb", delete_from_kb
+                )

                if final_connector_id != connector.id:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            and_(
                                SearchSourceConnector.id == final_connector_id,
-                                SearchSourceConnector.search_space_id == search_space_id,
+                                SearchSourceConnector.search_space_id
+                                == search_space_id,
                                SearchSourceConnector.user_id == user_id,
                                SearchSourceConnector.connector_type
                                == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                            )
                        )
                    )
                    validated_connector = result.scalars().first()
                    if not validated_connector:
                        return {
                            "status": "error",
                            "message": "Selected Dropbox connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = validated_connector.id
                else:
                    actual_connector_id = connector.id

                logger.info(
                    f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
                )

-                client = DropboxClient(session=db_session, connector_id=actual_connector_id)
+                client = DropboxClient(
+                    session=db_session, connector_id=actual_connector_id
+                )
                await client.delete_file(final_file_path)

                logger.info(f"Dropbox file deleted: path={final_file_path}")

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": file_id,
                    "message": f"Successfully deleted '{file_name}' from Dropbox.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        doc = doc_result.scalars().first()
                        if doc:
                            await db_session.delete(doc)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File deleted, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
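For reference, the approval flow used by the Dropbox hunks above follows a single contract: the user may edit the proposed parameters, and a rejection is terminal. Below is a hedged sketch of how a caller consumes that result; `request_approval` and its `rejected`/`params` fields are taken from the diff, while the wrapper function itself is hypothetical.

# Illustrative sketch, not code from this commit.
from app.agents.new_chat.tools.hitl import request_approval


def confirm_trash(file_path: str, connector_id: int, delete_from_kb: bool) -> dict:
    result = request_approval(
        action_type="dropbox_file_trash",   # action and tool names from the hunk above
        tool_name="delete_dropbox_file",
        params={
            "file_path": file_path,
            "connector_id": connector_id,
            "delete_from_kb": delete_from_kb,
        },
        context={},
    )
    if result.rejected:
        # Terminal: the tool must report "rejected" and never retry.
        return {"status": "rejected"}
    # User-edited values take precedence over the originally proposed ones.
    return {
        "status": "approved",
        "file_path": result.params.get("file_path", file_path),
        "connector_id": result.params.get("connector_id", connector_id),
        "delete_from_kb": result.params.get("delete_from_kb", delete_from_kb),
    }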
@@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
    ImageGenRouterService,
    is_image_gen_auto_mode,
)
+from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token

logger = logging.getLogger(__name__)

@@ -49,12 +50,16 @@ _PROVIDER_MAP = {
}


+def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
+    if custom_provider:
+        return custom_provider
+    return _PROVIDER_MAP.get(provider.upper(), provider.lower())
+
+
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
-    if custom_provider:
-        return f"{custom_provider}/{model_name}"
-    prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
+    prefix = _resolve_provider_prefix(provider, custom_provider)
    return f"{prefix}/{model_name}"

@@ -146,14 +151,18 @@ def create_generate_image_tool(
                    "error": f"Image generation config {config_id} not found"
                }

-                model_string = _build_model_string(
-                    cfg.get("provider", ""),
-                    cfg["model_name"],
-                    cfg.get("custom_provider"),
-                )
+                provider_prefix = _resolve_provider_prefix(
+                    cfg.get("provider", ""), cfg.get("custom_provider")
+                )
+                model_string = f"{provider_prefix}/{cfg['model_name']}"
                gen_kwargs["api_key"] = cfg.get("api_key")
-                if cfg.get("api_base"):
-                    gen_kwargs["api_base"] = cfg["api_base"]
+                api_base = resolve_api_base(
+                    provider=cfg.get("provider"),
+                    provider_prefix=provider_prefix,
+                    config_api_base=cfg.get("api_base"),
+                )
+                if api_base:
+                    gen_kwargs["api_base"] = api_base
                if cfg.get("api_version"):
                    gen_kwargs["api_version"] = cfg["api_version"]
                if cfg.get("litellm_params"):

@@ -175,14 +184,18 @@ def create_generate_image_tool(
                    "error": f"Image generation config {config_id} not found"
                }

-                model_string = _build_model_string(
-                    db_cfg.provider.value,
-                    db_cfg.model_name,
-                    db_cfg.custom_provider,
-                )
+                provider_prefix = _resolve_provider_prefix(
+                    db_cfg.provider.value, db_cfg.custom_provider
+                )
+                model_string = f"{provider_prefix}/{db_cfg.model_name}"
                gen_kwargs["api_key"] = db_cfg.api_key
-                if db_cfg.api_base:
-                    gen_kwargs["api_base"] = db_cfg.api_base
+                api_base = resolve_api_base(
+                    provider=db_cfg.provider.value,
+                    provider_prefix=provider_prefix,
+                    config_api_base=db_cfg.api_base,
+                )
+                if api_base:
+                    gen_kwargs["api_base"] = api_base
                if db_cfg.api_version:
                    gen_kwargs["api_version"] = db_cfg.api_version
                if db_cfg.litellm_params:
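The image-generation hunks above factor the provider-prefix lookup out of `_build_model_string` so the same prefix can drive both the model string and the new `resolve_api_base` call. A hedged sketch of how the pieces compose follows; only the function shapes come from the diff, while the map entries and example values are assumptions.

# Illustrative sketch, not code from this commit.
from app.services.provider_api_base import resolve_api_base  # import added above

_PROVIDER_MAP = {"OPENAI": "openai"}  # assumed entry; real map lives in the source file


def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
    # A custom provider wins outright; otherwise fall back to the map.
    if custom_provider:
        return custom_provider
    return _PROVIDER_MAP.get(provider.upper(), provider.lower())


def build_generation_kwargs(cfg: dict) -> dict:
    prefix = _resolve_provider_prefix(cfg.get("provider", ""), cfg.get("custom_provider"))
    kwargs = {"model": f"{prefix}/{cfg['model_name']}", "api_key": cfg.get("api_key")}
    api_base = resolve_api_base(
        provider=cfg.get("provider"),
        provider_prefix=prefix,
        config_api_base=cfg.get("api_base"),
    )
    if api_base:
        kwargs["api_base"] = api_base
    return kwargs


# Example with assumed values:
# build_generation_kwargs({"provider": "OPENAI", "model_name": "gpt-image-1"})
# -> {"model": "openai/gpt-image-1", "api_key": None}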
@@ -0,0 +1,41 @@
+from typing import Any
+
+from app.db import SearchSourceConnector
+from app.services.composio_service import ComposioService
+
+
+def split_recipients(value: str | None) -> list[str]:
+    if not value:
+        return []
+    return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
+
+
+def unwrap_composio_data(data: Any) -> Any:
+    if isinstance(data, dict):
+        inner = data.get("data", data)
+        if isinstance(inner, dict):
+            return inner.get("response_data", inner)
+        return inner
+    return data
+
+
+async def execute_composio_gmail_tool(
+    connector: SearchSourceConnector,
+    user_id: str,
+    tool_name: str,
+    params: dict[str, Any],
+) -> tuple[Any, str | None]:
+    cca_id = connector.config.get("composio_connected_account_id")
+    if not cca_id:
+        return None, "Composio connected account ID not found for this Gmail connector."
+
+    result = await ComposioService().execute_tool(
+        connected_account_id=cca_id,
+        tool_name=tool_name,
+        params=params,
+        entity_id=f"surfsense_{user_id}",
+    )
+    if not result.get("success"):
+        return None, result.get("error", "Unknown Composio Gmail error")
+
+    return unwrap_composio_data(result.get("data")), None

@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)

@@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
+    """
+    Factory function to create the create_gmail_draft tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_gmail_draft tool
+    """
+    del db_session  # per-call session — see docstring
+
    @tool
    async def create_gmail_draft(
        to: str,

@@ -57,246 +75,276 @@ def create_create_gmail_draft_tool(
        """
        logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")

-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
+            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )
                if "error" in context:
-                    logger.error(f"Failed to fetch creation context: {context['error']}")
+                    logger.error(
+                        f"Failed to fetch creation context: {context['error']}"
+                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning("All Gmail accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                logger.info(
                    f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
                )
                result = request_approval(
                    action_type="gmail_draft_creation",
                    tool_name="create_gmail_draft",
                    params={
                        "to": to,
                        "subject": subject,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
                    }

                final_to = result.params.get("to", to)
                final_subject = result.params.get("subject", subject)
                final_body = result.params.get("body", body)
                final_cc = result.params.get("cc", cc)
                final_bcc = result.params.get("bcc", bcc)
                final_connector_id = result.params.get("connector_id")

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_gmail_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Gmail connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = connector.id
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_gmail_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                        }
                    actual_connector_id = connector.id

                logger.info(
                    f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
                )

-                if (
+                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
-                ):
-                    from app.utils.google_credentials import build_composio_credentials
-
+                )
+                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
-                    if cca_id:
-                        creds = build_composio_credentials(cca_id)
-                    else:
+                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption
-
-                    config_data = dict(connector.config)
-                    token_encrypted = config_data.get("_token_encrypted", False)
-                    if token_encrypted and config.SECRET_KEY:
-                        token_encryption = TokenEncryption(config.SECRET_KEY)
-                        if config_data.get("token"):
-                            config_data["token"] = token_encryption.decrypt_token(
-                                config_data["token"]
-                            )
-                        if config_data.get("refresh_token"):
-                            config_data["refresh_token"] = token_encryption.decrypt_token(
-                                config_data["refresh_token"]
-                            )
-                        if config_data.get("client_secret"):
-                            config_data["client_secret"] = token_encryption.decrypt_token(
-                                config_data["client_secret"]
-                            )
-
-                    exp = config_data.get("expiry", "")
-                    if exp:
-                        exp = exp.replace("Z", "")
-
-                    creds = Credentials(
-                        token=config_data.get("token"),
-                        refresh_token=config_data.get("refresh_token"),
-                        token_uri=config_data.get("token_uri"),
-                        client_id=config_data.get("client_id"),
-                        client_secret=config_data.get("client_secret"),
-                        scopes=config_data.get("scopes", []),
-                        expiry=datetime.fromisoformat(exp) if exp else None,
-                    )
-
-                from googleapiclient.discovery import build
-
-                gmail_service = build("gmail", "v1", credentials=creds)
-
-                message = MIMEText(final_body)
-                message["to"] = final_to
-                message["subject"] = final_subject
-                if final_cc:
-                    message["cc"] = final_cc
-                if final_bcc:
-                    message["bcc"] = final_bcc
-                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
-
-                try:
-                    created = await asyncio.get_event_loop().run_in_executor(
-                        None,
-                        lambda: (
-                            gmail_service.users()
-                            .drafts()
-                            .create(userId="me", body={"message": {"raw": raw}})
-                            .execute()
-                        ),
-                    )
-                except Exception as api_err:
-                    from googleapiclient.errors import HttpError
-
-                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
-                        logger.warning(
-                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
-                        )
-                        try:
-                            from sqlalchemy.orm.attributes import flag_modified
-
-                            _res = await db_session.execute(
-                                select(SearchSourceConnector).where(
-                                    SearchSourceConnector.id == actual_connector_id
|
||||||
|
|
||||||
|
config_data = dict(connector.config)
|
||||||
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
if config_data.get("token"):
|
||||||
|
config_data["token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["token"]
|
||||||
)
|
)
|
||||||
|
if config_data.get("refresh_token"):
|
||||||
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
|
config_data["refresh_token"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if config_data.get("client_secret"):
|
||||||
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
|
config_data["client_secret"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
exp = config_data.get("expiry", "")
|
||||||
|
if exp:
|
||||||
|
exp = exp.replace("Z", "")
|
||||||
|
|
||||||
|
creds = Credentials(
|
||||||
|
token=config_data.get("token"),
|
||||||
|
refresh_token=config_data.get("refresh_token"),
|
||||||
|
token_uri=config_data.get("token_uri"),
|
||||||
|
client_id=config_data.get("client_id"),
|
||||||
|
client_secret=config_data.get("client_secret"),
|
||||||
|
scopes=config_data.get("scopes", []),
|
||||||
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = MIMEText(final_body)
|
||||||
|
message["to"] = final_to
|
||||||
|
message["subject"] = final_subject
|
||||||
|
if final_cc:
|
||||||
|
message["cc"] = final_cc
|
||||||
|
if final_bcc:
|
||||||
|
message["bcc"] = final_bcc
|
||||||
|
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
split_recipients,
|
||||||
)
|
)
|
||||||
_conn = _res.scalar_one_or_none()
|
|
||||||
if _conn and not _conn.config.get("auth_expired"):
|
created, error = await execute_composio_gmail_tool(
|
||||||
_conn.config = {**_conn.config, "auth_expired": True}
|
connector,
|
||||||
flag_modified(_conn, "config")
|
user_id,
|
||||||
await db_session.commit()
|
"GMAIL_CREATE_EMAIL_DRAFT",
|
||||||
except Exception:
|
{
|
||||||
|
"user_id": "me",
|
||||||
|
"recipient_email": final_to,
|
||||||
|
"subject": final_subject,
|
||||||
|
"body": final_body,
|
||||||
|
"cc": split_recipients(final_cc),
|
||||||
|
"bcc": split_recipients(final_bcc),
|
||||||
|
"is_html": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
if not isinstance(created, dict):
|
||||||
|
created = {}
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
created = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.drafts()
|
||||||
|
.create(userId="me", body={"message": {"raw": raw}})
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as api_err:
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to persist auth_expired for connector %s",
|
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||||
actual_connector_id,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
return {
|
try:
|
||||||
"status": "insufficient_permissions",
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
"connector_id": actual_connector_id,
|
|
||||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
|
||||||
}
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
_res = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).where(
|
||||||
|
SearchSourceConnector.id == actual_connector_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_conn = _res.scalar_one_or_none()
|
||||||
|
if _conn and not _conn.config.get("auth_expired"):
|
||||||
|
_conn.config = {**_conn.config, "auth_expired": True}
|
||||||
|
flag_modified(_conn, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to persist auth_expired for connector %s",
|
||||||
|
actual_connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": actual_connector_id,
|
||||||
|
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
kb_message_suffix = ""
|
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||||
try:
|
|
||||||
from app.services.gmail import GmailKBSyncService
|
|
||||||
|
|
||||||
kb_service = GmailKBSyncService(db_session)
|
kb_message_suffix = ""
|
||||||
draft_message = created.get("message", {})
|
try:
|
||||||
kb_result = await kb_service.sync_after_create(
|
from app.services.gmail import GmailKBSyncService
|
||||||
message_id=draft_message.get("id", ""),
|
|
||||||
thread_id=draft_message.get("threadId", ""),
|
kb_service = GmailKBSyncService(db_session)
|
||||||
subject=final_subject,
|
draft_message = created.get("message", {})
|
||||||
sender="me",
|
kb_result = await kb_service.sync_after_create(
|
||||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
message_id=draft_message.get("id", ""),
|
||||||
body_text=final_body,
|
thread_id=draft_message.get("threadId", ""),
|
||||||
connector_id=actual_connector_id,
|
subject=final_subject,
|
||||||
search_space_id=search_space_id,
|
sender="me",
|
||||||
user_id=user_id,
|
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
draft_id=created.get("id"),
|
body_text=final_body,
|
||||||
)
|
connector_id=actual_connector_id,
|
||||||
if kb_result["status"] == "success":
|
search_space_id=search_space_id,
|
||||||
kb_message_suffix = " Your knowledge base has also been updated."
|
user_id=user_id,
|
||||||
else:
|
draft_id=created.get("id"),
|
||||||
|
)
|
||||||
|
if kb_result["status"] == "success":
|
||||||
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||||
|
except Exception as kb_err:
|
||||||
|
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||||
except Exception as kb_err:
|
|
||||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
|
||||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"draft_id": created.get("id"),
|
"draft_id": created.get("id"),
|
||||||
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
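The draft-creation path above reduces to the standard Gmail API pattern: build a MIME message, base64url-encode its raw bytes, and run the blocking drafts().create call in a thread so the event loop is not blocked. A minimal standalone sketch for reference; the function name create_draft_sketch and the creds argument are placeholders and not part of this diff:

import asyncio
import base64
from email.mime.text import MIMEText

from googleapiclient.discovery import build


async def create_draft_sketch(creds, to: str, subject: str, body: str) -> dict:
    # Build the RFC 2822 message and encode it the way the Gmail API expects:
    # URL-safe base64 of the raw message bytes.
    message = MIMEText(body)
    message["to"] = to
    message["subject"] = subject
    raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

    service = build("gmail", "v1", credentials=creds)

    # googleapiclient is synchronous; push the network call onto the default
    # executor so an async tool can await it without blocking the loop.
    return await asyncio.get_event_loop().run_in_executor(
        None,
        lambda: service.users()
        .drafts()
        .create(userId="me", body={"message": {"raw": raw}})
        .execute(),
    )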
@@ -5,7 +5,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker

logger = logging.getLogger(__name__)

@@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the read_gmail_email tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured read_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def read_gmail_email(message_id: str) -> dict[str, Any]:
        """Read the full content of a specific Gmail email by its message ID.

@@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
        Returns:
            Dictionary with status and the full email content formatted as markdown.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                ):
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found.",
                        }

                    from app.agents.new_chat.tools.gmail.search_emails import (
                        _format_gmail_summary,
                    )
                    from app.services.composio_service import ComposioService

                    service = ComposioService()
                    detail, error = await service.get_gmail_message_detail(
                        connected_account_id=cca_id,
                        entity_id=f"surfsense_{user_id}",
                        message_id=message_id,
                    )
                    if error:
                        return {"status": "error", "message": error}
                    if not detail:
                        return {
                            "status": "not_found",
                            "message": f"Email with ID '{message_id}' not found.",
                        }

                    summary = _format_gmail_summary(detail)
                    content = (
                        f"# {summary['subject']}\n\n"
                        f"**From:** {summary['from']}\n"
                        f"**To:** {summary['to']}\n"
                        f"**Date:** {summary['date']}\n\n"
                        f"## Message Content\n\n"
                        f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
                        f"## Message Details\n\n"
                        f"- **Message ID:** {summary['message_id']}\n"
                        f"- **Thread ID:** {summary['thread_id']}\n"
                    )
                    return {
                        "status": "success",
                        "message_id": summary["message_id"] or message_id,
                        "content": content,
                    }

                from app.agents.new_chat.tools.gmail.search_emails import (
                    _build_credentials,
                )

                creds = _build_credentials(connector)

                from app.connectors.google_gmail_connector import GoogleGmailConnector

                gmail = GoogleGmailConnector(
                    credentials=creds,
                    session=db_session,
                    user_id=user_id,
                    connector_id=connector.id,
                )

                detail, error = await gmail.get_message_details(message_id)
                if error:
                    if (
                        "re-authenticate" in error.lower()
                        or "authentication failed" in error.lower()
                    ):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "gmail",
                        }
                    return {"status": "error", "message": error}

                if not detail:
                    return {
                        "status": "not_found",
                        "message": f"Email with ID '{message_id}' not found.",
                    }

                content = gmail.format_message_to_markdown(detail)

                return {
                    "status": "success",
                    "message_id": message_id,
                    "content": content,
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
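The docstrings added in this change describe the per-call session pattern that all of these tool factories now share: the injected request-scoped session is discarded and each tool invocation opens its own session from async_session_maker. A minimal sketch of that shape, with the hypothetical names create_example_tool and example_tool standing in for any of the factories here:

from app.db import async_session_maker  # project-provided async session factory


def create_example_tool(db_session=None, search_space_id=None, user_id=None):
    # The injected db_session is intentionally unused: the compiled-agent
    # cache may reuse this closure across HTTP requests, so it must not
    # capture a request-scoped session.
    del db_session

    async def example_tool() -> dict:
        # Open a short-lived session for this single call and release it when
        # the call finishes, even if an exception is raised.
        async with async_session_maker() as session:
            ...  # queries against `session` go here
        return {"status": "success"}

    return example_tool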
@@ -6,7 +6,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker

logger = logging.getLogger(__name__)

@@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
    from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        raise ValueError("Composio connectors must use Composio tool execution.")

    from google.oauth2.credentials import Credentials

@@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
    )


def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
    headers = message.get("payload", {}).get("headers", [])
    return {
        header.get("name", "").lower(): header.get("value", "")
        for header in headers
        if isinstance(header, dict)
    }


def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
    headers = _gmail_headers(message)
    return {
        "message_id": message.get("id") or message.get("messageId"),
        "thread_id": message.get("threadId"),
        "subject": message.get("subject") or headers.get("subject", "No Subject"),
        "from": message.get("sender") or headers.get("from", "Unknown"),
        "to": message.get("to") or headers.get("to", ""),
        "date": message.get("messageTimestamp") or headers.get("date", ""),
        "snippet": message.get("snippet") or message.get("messageText", "")[:300],
        "labels": message.get("labelIds", []),
    }


async def _search_composio_gmail(
    connector: SearchSourceConnector,
    user_id: str,
    query: str,
    max_results: int,
) -> dict[str, Any]:
    cca_id = connector.config.get("composio_connected_account_id")
    if not cca_id:
        return {
            "status": "error",
            "message": "Composio connected account ID not found.",
        }

    from app.services.composio_service import ComposioService

    service = ComposioService()
    messages, _next_token, _estimate, error = await service.get_gmail_messages(
        connected_account_id=cca_id,
        entity_id=f"surfsense_{user_id}",
        query=query,
        max_results=max_results,
    )
    if error:
        return {"status": "error", "message": error}

    emails = [_format_gmail_summary(message) for message in messages]
    return {
        "status": "success",
        "emails": emails,
        "total": len(emails),
        "message": "No emails found." if not emails else None,
    }


def create_search_gmail_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the search_gmail tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured search_gmail tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def search_gmail(
        query: str,

@@ -90,83 +159,92 @@ def create_search_gmail_tool(
        Returns:
            Dictionary with status and a list of email summaries including
            message_id, subject, from, date, snippet.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        max_results = min(max_results, 20)

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                ):
                    return await _search_composio_gmail(
                        connector, str(user_id), query, max_results
                    )

                creds = _build_credentials(connector)

                from app.connectors.google_gmail_connector import GoogleGmailConnector

                gmail = GoogleGmailConnector(
                    credentials=creds,
                    session=db_session,
                    user_id=user_id,
                    connector_id=connector.id,
                )

                messages_list, error = await gmail.get_messages_list(
                    max_results=max_results, query=query
                )
                if error:
                    if (
                        "re-authenticate" in error.lower()
                        or "authentication failed" in error.lower()
                    ):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "gmail",
                        }
                    return {"status": "error", "message": error}

                if not messages_list:
                    return {
                        "status": "success",
                        "emails": [],
                        "total": 0,
                        "message": "No emails found.",
                    }

                emails = []
                for msg in messages_list:
                    detail, err = await gmail.get_message_details(msg["id"])
                    if err:
                        continue
                    headers = {
                        h["name"].lower(): h["value"]
                        for h in detail.get("payload", {}).get("headers", [])
                    }
                    emails.append(
                        {
                            "message_id": detail.get("id"),
                            "thread_id": detail.get("threadId"),
                            "subject": headers.get("subject", "No Subject"),
                            "from": headers.get("from", "Unknown"),
                            "to": headers.get("to", ""),
                            "date": headers.get("date", ""),
                            "snippet": detail.get("snippet", ""),
                            "labels": detail.get("labelIds", []),
                        }
                    )

                return {"status": "success", "emails": emails, "total": len(emails)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt
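The new _format_gmail_summary helper normalizes two message shapes into one summary dict: Composio results that carry top-level fields (sender, messageText, messageTimestamp) and raw Gmail API results whose metadata lives in payload.headers name/value pairs. A small illustration with a made-up raw-API-shaped payload; the IDs, addresses, and subject below are sample data, not taken from the diff:

from app.agents.new_chat.tools.gmail.search_emails import _format_gmail_summary

# Raw Gmail API shape: headers live under payload.headers as name/value pairs.
raw_api_message = {
    "id": "18f3example",
    "threadId": "18f3example",
    "snippet": "Quarterly numbers attached.",
    "labelIds": ["INBOX"],
    "payload": {
        "headers": [
            {"name": "Subject", "value": "Q3 report"},
            {"name": "From", "value": "alice@example.com"},
            {"name": "To", "value": "bob@example.com"},
            {"name": "Date", "value": "Tue, 01 Oct 2024 09:00:00 +0000"},
        ]
    },
}

summary = _format_gmail_summary(raw_api_message)
# Header names are lower-cased before lookup, so the summary comes out as
# {"message_id": "18f3example", "subject": "Q3 report",
#  "from": "alice@example.com", "to": "bob@example.com", ...}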
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)

@@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the send_gmail_email tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured send_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def send_gmail_email(
        to: str,

@@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
        """
        logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning("All Gmail accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                logger.info(
                    f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
                )
                result = request_approval(
                    action_type="gmail_email_send",
                    tool_name="send_gmail_email",
                    params={
                        "to": to,
                        "subject": subject,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
                    }

                final_to = result.params.get("to", to)
                final_subject = result.params.get("subject", subject)
                final_body = result.params.get("body", body)
                final_cc = result.params.get("cc", cc)
                final_bcc = result.params.get("bcc", bcc)
                final_connector_id = result.params.get("connector_id")

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_gmail_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Gmail connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = connector.id
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_gmail_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                        }
                    actual_connector_id = connector.id

                logger.info(
                    f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
                )

                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption

                    config_data = dict(connector.config)
                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and config.SECRET_KEY:
                        token_encryption = TokenEncryption(config.SECRET_KEY)
                        if config_data.get("token"):
                            config_data["token"] = token_encryption.decrypt_token(
                                config_data["token"]
                            )
                        if config_data.get("refresh_token"):
                            config_data["refresh_token"] = (
                                token_encryption.decrypt_token(
                                    config_data["refresh_token"]
                                )
                            )
                        if config_data.get("client_secret"):
                            config_data["client_secret"] = (
                                token_encryption.decrypt_token(
                                    config_data["client_secret"]
                                )
                            )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                message = MIMEText(final_body)
                message["to"] = final_to
                message["subject"] = final_subject
                if final_cc:
                    message["cc"] = final_cc
                if final_bcc:
                    message["bcc"] = final_bcc
                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

                try:
                    if is_composio_gmail:
                        from app.agents.new_chat.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                            split_recipients,
                        )

                        sent, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_SEND_EMAIL",
                            {
                                "user_id": "me",
                                "recipient_email": final_to,
                                "subject": final_subject,
                                "body": final_body,
                                "cc": split_recipients(final_cc),
                                "bcc": split_recipients(final_bcc),
                                "is_html": False,
                            },
                        )
                        if error:
                            raise RuntimeError(error)
                        if not isinstance(sent, dict):
                            sent = {}
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        sent = await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .messages()
                                .send(userId="me", body={"raw": raw})
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(
                    f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
                )

                kb_message_suffix = ""
                try:
                    from app.services.gmail import GmailKBSyncService

                    kb_service = GmailKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        message_id=sent.get("id", ""),
                        thread_id=sent.get("threadId", ""),
                        subject=final_subject,
                        sender="me",
                        date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        body_text=final_body,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after send failed: {kb_err}")
                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "message_id": sent.get("id"),
                    "thread_id": sent.get("threadId"),
                    "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
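All of the write-path tools in this change share the same human-in-the-loop shape: propose the parameters through request_approval, stop on rejection, and then prefer the possibly user-edited values from result.params, falling back to the originally proposed ones. A compact sketch of that consumption pattern, assuming only the request_approval interface shown above (result.rejected and result.params); the helper name approved_email_params is hypothetical:

from app.agents.new_chat.tools.hitl import request_approval


def approved_email_params(to: str, subject: str, body: str, context: dict) -> dict | None:
    # Surface the proposed action to the user; the agent graph interrupts
    # here and resumes with either a rejection or an (optionally edited)
    # params payload.
    result = request_approval(
        action_type="gmail_email_send",
        tool_name="send_gmail_email",
        params={"to": to, "subject": subject, "body": body, "connector_id": None},
        context=context,
    )
    if result.rejected:
        return None
    # Fall back to the originally proposed values for anything the user
    # left untouched.
    return {
        "to": result.params.get("to", to),
        "subject": result.params.get("subject", subject),
        "body": result.params.get("body", body),
        "connector_id": result.params.get("connector_id"),
    }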
@@ -7,6 +7,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)

@@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the trash_gmail_email tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured trash_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def trash_gmail_email(
        email_subject_or_id: str,

@@ -55,244 +73,261 @@ def create_trash_gmail_email_tool(
            f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_trash_context(
                    search_space_id, user_id, email_subject_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Email not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch trash context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Gmail account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                email = context["email"]
                message_id = email["message_id"]
                document_id = email.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not message_id:
                    return {
                        "status": "error",
                        "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
                    }

                logger.info(
                    f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
                )
                result = request_approval(
                    action_type="gmail_email_trash",
                    tool_name="trash_gmail_email",
                    params={
                        "message_id": message_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
                    }

                final_message_id = result.params.get("message_id", message_id)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this email.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
                )

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                ):
                    from app.utils.google_credentials import build_composio_credentials

                    cca_id = connector.config.get("composio_connected_account_id")
                    if cca_id:
                        creds = build_composio_credentials(cca_id)
                    else:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption

                    config_data = dict(connector.config)
                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and config.SECRET_KEY:
                        token_encryption = TokenEncryption(config.SECRET_KEY)
                        if config_data.get("token"):
                            config_data["token"] = token_encryption.decrypt_token(
                                config_data["token"]
                            )
                        if config_data.get("refresh_token"):
                            config_data["refresh_token"] = token_encryption.decrypt_token(
                                config_data["refresh_token"]
                            )
                        if config_data.get("client_secret"):
                            config_data["client_secret"] = token_encryption.decrypt_token(
                                config_data["client_secret"]
                            )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
|
|
||||||
token_uri=config_data.get("token_uri"),
|
|
||||||
client_id=config_data.get("client_id"),
|
|
||||||
client_secret=config_data.get("client_secret"),
|
|
||||||
scopes=config_data.get("scopes", []),
|
|
||||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from googleapiclient.discovery import build
|
if result.rejected:
|
||||||
|
|
||||||
gmail_service = build("gmail", "v1", credentials=creds)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: (
|
|
||||||
gmail_service.users()
|
|
||||||
.messages()
|
|
||||||
.trash(userId="me", id=final_message_id)
|
|
||||||
.execute()
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as api_err:
|
|
||||||
from googleapiclient.errors import HttpError
|
|
||||||
|
|
||||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
|
||||||
logger.warning(
|
|
||||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from sqlalchemy.orm.attributes import flag_modified
|
|
||||||
|
|
||||||
if not connector.config.get("auth_expired"):
|
|
||||||
connector.config = {
|
|
||||||
**connector.config,
|
|
||||||
"auth_expired": True,
|
|
||||||
}
|
|
||||||
flag_modified(connector, "config")
|
|
||||||
await db_session.commit()
|
|
||||||
except Exception:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to persist auth_expired for connector %s",
|
|
||||||
connector.id,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"status": "insufficient_permissions",
|
"status": "rejected",
|
||||||
"connector_id": connector.id,
|
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
|
||||||
}
|
}
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
final_message_id = result.params.get("message_id", message_id)
|
||||||
|
final_connector_id = result.params.get(
|
||||||
trash_result: dict[str, Any] = {
|
"connector_id", connector_id_from_context
|
||||||
"status": "success",
|
)
|
||||||
"message_id": final_message_id,
|
final_delete_from_kb = result.params.get(
|
||||||
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
"delete_from_kb", delete_from_kb
|
||||||
}
|
|
||||||
|
|
||||||
deleted_from_kb = False
|
|
||||||
if final_delete_from_kb and document_id:
|
|
||||||
try:
|
|
||||||
from app.db import Document
|
|
||||||
|
|
||||||
doc_result = await db_session.execute(
|
|
||||||
select(Document).filter(Document.id == document_id)
|
|
||||||
)
|
|
||||||
document = doc_result.scalars().first()
|
|
||||||
if document:
|
|
||||||
await db_session.delete(document)
|
|
||||||
await db_session.commit()
|
|
||||||
deleted_from_kb = True
|
|
||||||
logger.info(
|
|
||||||
f"Deleted document {document_id} from knowledge base"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Document {document_id} not found in KB")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete document from KB: {e}")
|
|
||||||
await db_session.rollback()
|
|
||||||
trash_result["warning"] = (
|
|
||||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
|
||||||
)
|
|
||||||
|
|
||||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
|
||||||
if deleted_from_kb:
|
|
||||||
trash_result["message"] = (
|
|
||||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return trash_result
|
if not final_connector_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No connector found for this email.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
_gmail_types = [
|
||||||
|
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == final_connector_id,
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
is_composio_gmail = (
|
||||||
|
connector.connector_type
|
||||||
|
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||||
|
)
|
||||||
|
if is_composio_gmail:
|
||||||
|
cca_id = connector.config.get("composio_connected_account_id")
|
||||||
|
if not cca_id:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
config_data = dict(connector.config)
|
||||||
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
if config_data.get("token"):
|
||||||
|
config_data["token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["token"]
|
||||||
|
)
|
||||||
|
if config_data.get("refresh_token"):
|
||||||
|
config_data["refresh_token"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
|
config_data["refresh_token"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if config_data.get("client_secret"):
|
||||||
|
config_data["client_secret"] = (
|
||||||
|
token_encryption.decrypt_token(
|
||||||
|
config_data["client_secret"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
exp = config_data.get("expiry", "")
|
||||||
|
if exp:
|
||||||
|
exp = exp.replace("Z", "")
|
||||||
|
|
||||||
|
creds = Credentials(
|
||||||
|
token=config_data.get("token"),
|
||||||
|
refresh_token=config_data.get("refresh_token"),
|
||||||
|
token_uri=config_data.get("token_uri"),
|
||||||
|
client_id=config_data.get("client_id"),
|
||||||
|
client_secret=config_data.get("client_secret"),
|
||||||
|
scopes=config_data.get("scopes", []),
|
||||||
|
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_composio_gmail:
|
||||||
|
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||||
|
execute_composio_gmail_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
_trashed, error = await execute_composio_gmail_tool(
|
||||||
|
connector,
|
||||||
|
user_id,
|
||||||
|
"GMAIL_MOVE_TO_TRASH",
|
||||||
|
{"user_id": "me", "message_id": final_message_id},
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
else:
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
gmail_service = build("gmail", "v1", credentials=creds)
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: (
|
||||||
|
gmail_service.users()
|
||||||
|
.messages()
|
||||||
|
.trash(userId="me", id=final_message_id)
|
||||||
|
.execute()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as api_err:
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||||
|
logger.warning(
|
||||||
|
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
if not connector.config.get("auth_expired"):
|
||||||
|
connector.config = {
|
||||||
|
**connector.config,
|
||||||
|
"auth_expired": True,
|
||||||
|
}
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to persist auth_expired for connector %s",
|
||||||
|
connector.id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": connector.id,
|
||||||
|
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||||
|
|
||||||
|
trash_result: dict[str, Any] = {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": final_message_id,
|
||||||
|
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted_from_kb = False
|
||||||
|
if final_delete_from_kb and document_id:
|
||||||
|
try:
|
||||||
|
from app.db import Document
|
||||||
|
|
||||||
|
doc_result = await db_session.execute(
|
||||||
|
select(Document).filter(Document.id == document_id)
|
||||||
|
)
|
||||||
|
document = doc_result.scalars().first()
|
||||||
|
if document:
|
||||||
|
await db_session.delete(document)
|
||||||
|
await db_session.commit()
|
||||||
|
deleted_from_kb = True
|
||||||
|
logger.info(
|
||||||
|
f"Deleted document {document_id} from knowledge base"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Document {document_id} not found in KB")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete document from KB: {e}")
|
||||||
|
await db_session.rollback()
|
||||||
|
trash_result["warning"] = (
|
||||||
|
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||||
|
)
|
||||||
|
|
||||||
|
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||||
|
if deleted_from_kb:
|
||||||
|
trash_result["message"] = (
|
||||||
|
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return trash_result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 from app.services.gmail import GmailToolMetadataService
 
 logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the update_gmail_draft tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured update_gmail_draft tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def update_gmail_draft(
         draft_subject_or_id: str,
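The docstring above describes the pattern all of these factories now share: the factory discards the injected request-scoped session, and the tool opens a fresh session on every invocation. A minimal sketch of that shape, with the factory and tool names invented for illustration (only async_session_maker and the @tool decorator come from the diff):

    from langchain_core.tools import tool

    from app.db import async_session_maker  # project session factory used in this diff


    def create_example_tool(db_session=None, search_space_id=None, user_id=None):
        # The cached closure must not capture db_session: it may already be
        # closed by the time a cache hit reuses this tool, so discard it here.
        del db_session

        @tool
        async def example_tool(query: str) -> dict:
            """Illustrative tool body; opens a short-lived session per call."""
            async with async_session_maker() as session:
                # ... use `session` for this call only ...
                return {"status": "success", "query": query}

        return example_tool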
update_gmail_draft, updated tool body. Same refactor as the trash tool: the session comes from async_session_maker() per call, the db_session None-check is dropped, and Composio accounts now go through execute_composio_gmail_tool (GMAIL_UPDATE_DRAFT, with draft lookup via GMAIL_LIST_DRAFTS) instead of building Google credentials with build_composio_credentials:

@@ -76,294 +94,329 @@ def create_update_gmail_draft_tool(
        logger.info(
            f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, draft_subject_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Draft not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch update context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Gmail account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                email = context["email"]
                message_id = email["message_id"]
                document_id = email.get("document_id")
                connector_id_from_context = account["id"]
                draft_id_from_context = context.get("draft_id")

                original_subject = email.get("subject", draft_subject_or_id)
                final_subject_default = subject if subject else original_subject
                final_to_default = to if to else ""

                logger.info(
                    f"Requesting approval for updating Gmail draft: '{original_subject}' "
                    f"(message_id={message_id}, draft_id={draft_id_from_context})"
                )
                result = request_approval(
                    action_type="gmail_draft_update",
                    tool_name="update_gmail_draft",
                    params={
                        "message_id": message_id,
                        "draft_id": draft_id_from_context,
                        "to": final_to_default,
                        "subject": final_subject_default,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
                    }

                final_to = result.params.get("to", final_to_default)
                final_subject = result.params.get("subject", final_subject_default)
                final_body = result.params.get("body", body)
                final_cc = result.params.get("cc", cc)
                final_bcc = result.params.get("bcc", bcc)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_draft_id = result.params.get("draft_id", draft_id_from_context)

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this draft.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
                )

                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption

                    config_data = dict(connector.config)
                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and config.SECRET_KEY:
                        token_encryption = TokenEncryption(config.SECRET_KEY)
                        if config_data.get("token"):
                            config_data["token"] = token_encryption.decrypt_token(
                                config_data["token"]
                            )
                        if config_data.get("refresh_token"):
                            config_data["refresh_token"] = (
                                token_encryption.decrypt_token(
                                    config_data["refresh_token"]
                                )
                            )
                        if config_data.get("client_secret"):
                            config_data["client_secret"] = (
                                token_encryption.decrypt_token(
                                    config_data["client_secret"]
                                )
                            )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                # Resolve draft_id if not already available
                if not final_draft_id:
                    logger.info(
                        f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
                    )
                    if is_composio_gmail:
                        final_draft_id = await _find_composio_draft_id_by_message(
                            connector, user_id, message_id
                        )
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        final_draft_id = await _find_draft_id_by_message(
                            gmail_service, message_id
                        )

                if not final_draft_id:
                    return {
                        "status": "error",
                        "message": (
                            "Could not find this draft in Gmail. "
                            "It may have already been sent or deleted."
                        ),
                    }

                message = MIMEText(final_body)
                if final_to:
                    message["to"] = final_to
                message["subject"] = final_subject
                if final_cc:
                    message["cc"] = final_cc
                if final_bcc:
                    message["bcc"] = final_bcc
                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

                try:
                    if is_composio_gmail:
                        from app.agents.new_chat.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                            split_recipients,
                        )

                        updated, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_UPDATE_DRAFT",
                            {
                                "user_id": "me",
                                "draft_id": final_draft_id,
                                "recipient_email": final_to,
                                "subject": final_subject,
                                "body": final_body,
                                "cc": split_recipients(final_cc),
                                "bcc": split_recipients(final_bcc),
                                "is_html": False,
                            },
                        )
                        if error:
                            raise RuntimeError(error)
                        if not isinstance(updated, dict):
                            updated = {}
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        updated = await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .drafts()
                                .update(
                                    userId="me",
                                    id=final_draft_id,
                                    body={"message": {"raw": raw}},
                                )
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        # Same handling as in the trash tool: log the failure, mark
                        # connector.config["auth_expired"] = True (flag_modified +
                        # commit, warning on failure) and return
                        # {"status": "insufficient_permissions", "connector_id": ...,
                        #  "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings."}
                        ...
                    if isinstance(api_err, HttpError) and api_err.resp.status == 404:
                        return {
                            "status": "error",
                            "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                        }
                    raise

                logger.info(f"Gmail draft updated: id={updated.get('id')}")

                kb_message_suffix = ""
                if document_id:
                    try:
                        from sqlalchemy.future import select as sa_select
                        from sqlalchemy.orm.attributes import flag_modified

                        from app.db import Document

                        doc_result = await db_session.execute(
                            sa_select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            document.source_markdown = final_body
                            document.title = final_subject
                            meta = dict(document.document_metadata or {})
                            meta["subject"] = final_subject
                            meta["draft_id"] = updated.get("id", final_draft_id)
                            updated_msg = updated.get("message", {})
                            if updated_msg.get("id"):
                                meta["message_id"] = updated_msg["id"]
                            document.document_metadata = meta
                            flag_modified(document, "document_metadata")
                            await db_session.commit()
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                            logger.info(
                                f"KB document {document_id} updated for draft {final_draft_id}"
                            )
                        else:
                            kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
                    except Exception as kb_err:
                        logger.warning(f"KB update after draft edit failed: {kb_err}")
                        await db_session.rollback()
                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "draft_id": updated.get("id"),
                    "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
@@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
     except Exception as e:
         logger.warning(f"Failed to look up draft by message_id: {e}")
         return None
+
+
+async def _find_composio_draft_id_by_message(
+    connector: Any, user_id: str, message_id: str
+) -> str | None:
+    from app.agents.new_chat.tools.gmail.composio_helpers import (
+        execute_composio_gmail_tool,
+    )
+
+    page_token = ""
+    while True:
+        params: dict[str, Any] = {
+            "user_id": "me",
+            "max_results": 100,
+            "verbose": False,
+        }
+        if page_token:
+            params["page_token"] = page_token
+
+        data, error = await execute_composio_gmail_tool(
+            connector, user_id, "GMAIL_LIST_DRAFTS", params
+        )
+        if error or not isinstance(data, dict):
+            return None
+
+        for draft in data.get("drafts", []):
+            if draft.get("message", {}).get("id") == message_id:
+                return draft.get("id")
+
+        page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
+        if not page_token:
+            return None
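The new helper above walks Composio's draft listing one page at a time until it finds the draft whose message id matches, accepting either nextPageToken or next_page_token, presumably to tolerate both response casings. The same page-token loop in isolation, against a hypothetical fetch_page coroutine that is not part of this codebase:

    from collections.abc import Awaitable, Callable
    from typing import Any


    async def find_first(
        fetch_page: Callable[[str], Awaitable[dict[str, Any]]],
        matches: Callable[[dict[str, Any]], bool],
    ) -> dict[str, Any] | None:
        """Generic page-token scan: fetch pages until an item matches or pages run out."""
        page_token = ""
        while True:
            page = await fetch_page(page_token)
            for item in page.get("items", []):
                if matches(item):
                    return item
            page_token = page.get("nextPageToken") or page.get("next_page_token") or ""
            if not page_token:
                return None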
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 from app.services.google_calendar import GoogleCalendarToolMetadataService
 
 logger = logging.getLogger(__name__)
@@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the create_calendar_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_calendar_event tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def create_calendar_event(
         summary: str,
create_calendar_event, updated tool body. Same per-call session refactor; the db_session None-check is dropped, and Composio-connected accounts now create the event through ComposioService.execute_tool (GOOGLECALENDAR_CREATE_EVENT) instead of building Google credentials with build_composio_credentials:

@@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
        logger.info(
            f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleCalendarToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning(
                        "All Google Calendar accounts have expired authentication"
                    )
                    return {
                        "status": "auth_error",
                        "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_calendar",
                    }

                logger.info(
                    f"Requesting approval for creating calendar event: summary='{summary}'"
                )
                result = request_approval(
                    action_type="google_calendar_event_creation",
                    tool_name="create_calendar_event",
                    params={
                        "summary": summary,
                        "start_datetime": start_datetime,
                        "end_datetime": end_datetime,
                        "description": description,
                        "location": location,
                        "attendees": attendees,
                        "timezone": context.get("timezone"),
                        "connector_id": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
                    }

                final_summary = result.params.get("summary", summary)
                final_start_datetime = result.params.get(
                    "start_datetime", start_datetime
                )
                final_end_datetime = result.params.get("end_datetime", end_datetime)
                final_description = result.params.get("description", description)
                final_location = result.params.get("location", location)
                final_attendees = result.params.get("attendees", attendees)
                final_connector_id = result.params.get("connector_id")

                if not final_summary or not final_summary.strip():
                    return {
                        "status": "error",
                        "message": "Event summary cannot be empty.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _calendar_types = [
                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
                ]

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_calendar_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = connector.id
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_calendar_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                        }
                    actual_connector_id = connector.id

                logger.info(
                    f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
                )

                is_composio_calendar = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                )
                if is_composio_calendar:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }
                else:
                    config_data = dict(connector.config)

                    from app.config import config as app_config
                    from app.utils.oauth_security import TokenEncryption

                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and app_config.SECRET_KEY:
                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
                        for key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(key):
                                config_data[key] = token_encryption.decrypt_token(
                                    config_data[key]
                                )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                tz = context.get("timezone", "UTC")
                event_body: dict[str, Any] = {
                    "summary": final_summary,
                    "start": {"dateTime": final_start_datetime, "timeZone": tz},
                    "end": {"dateTime": final_end_datetime, "timeZone": tz},
                }
                if final_description:
                    event_body["description"] = final_description
                if final_location:
                    event_body["location"] = final_location
                if final_attendees:
                    event_body["attendees"] = [
                        {"email": e.strip()} for e in final_attendees if e.strip()
                    ]

                try:
                    if is_composio_calendar:
                        from app.services.composio_service import ComposioService

                        composio_params = {
                            "calendar_id": "primary",
                            "summary": final_summary,
                            "start_datetime": final_start_datetime,
                            "end_datetime": final_end_datetime,
                            "timezone": tz,
                            "attendees": final_attendees or [],
                        }
                        if final_description:
                            composio_params["description"] = final_description
                        if final_location:
                            composio_params["location"] = final_location

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLECALENDAR_CREATE_EVENT",
                            params=composio_params,
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Calendar error"
                                )
                            )
                        created = composio_result.get("data", {})
                        if isinstance(created, dict):
                            created = created.get("data", created)
                            if isinstance(created, dict):
                                created = created.get("response_data", created)
                    else:
                        service = await asyncio.get_event_loop().run_in_executor(
                            None, lambda: build("calendar", "v3", credentials=creds)
                        )
                        created = await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda: (
                                service.events()
                                .insert(calendarId="primary", body=event_body)
                                .execute()
                            ),
                        )
                except Exception as api_err:

The remainder of this hunk keeps the previous error handling and follow-up, now inside the per-call session scope: on an HTTP 403 it re-queries the connector by actual_connector_id, flags auth_expired in its config (flag_modified plus commit, warning on failure) and returns "insufficient_permissions"; otherwise it re-raises, logs "Calendar event created", syncs the new event into the knowledge base via GoogleCalendarKBSyncService.sync_after_create (falling back to a "next scheduled sync" suffix on failure), and returns the success payload with event_id, html_link and a message that includes the KB-sync outcome.
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||||
|
logger.warning(
|
||||||
|
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
_res = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).where(
|
||||||
|
SearchSourceConnector.id == actual_connector_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_conn = _res.scalar_one_or_none()
|
||||||
|
if _conn and not _conn.config.get("auth_expired"):
|
||||||
|
_conn.config = {**_conn.config, "auth_expired": True}
|
||||||
|
flag_modified(_conn, "config")
|
||||||
|
await db_session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to persist auth_expired for connector %s",
|
||||||
|
actual_connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "insufficient_permissions",
|
||||||
|
"connector_id": actual_connector_id,
|
||||||
|
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||||
|
}
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
kb_message_suffix = ""
|
||||||
|
try:
|
||||||
|
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||||
|
|
||||||
|
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||||
|
kb_result = await kb_service.sync_after_create(
|
||||||
|
event_id=created.get("id"),
|
||||||
|
event_summary=final_summary,
|
||||||
|
calendar_id="primary",
|
||||||
|
start_time=final_start_datetime,
|
||||||
|
end_time=final_end_datetime,
|
||||||
|
location=final_location,
|
||||||
|
html_link=created.get("htmlLink"),
|
||||||
|
description=final_description,
|
||||||
|
connector_id=actual_connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if kb_result["status"] == "success":
|
||||||
|
kb_message_suffix = (
|
||||||
|
" Your knowledge base has also been updated."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||||
|
except Exception as kb_err:
|
||||||
|
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||||
|
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"event_id": created.get("id"),
|
||||||
|
"html_link": created.get("htmlLink"),
|
||||||
|
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
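The non-Composio branch above drives the synchronous googleapiclient through run_in_executor. Below is a minimal sketch of that pattern, assuming the same asyncio/build imports the tool module uses and an already-built Credentials object; the insert_event helper name is illustrative and not part of the codebase.

import asyncio

from googleapiclient.discovery import build


async def insert_event(creds, event_body: dict) -> dict:
    loop = asyncio.get_event_loop()
    # build() can perform blocking I/O, so keep it off the event loop.
    service = await loop.run_in_executor(
        None, lambda: build("calendar", "v3", credentials=creds)
    )
    # .execute() is the blocking HTTP call; wrap it the same way.
    return await loop.run_in_executor(
        None,
        lambda: service.events()
        .insert(calendarId="primary", body=event_body)
        .execute(),
    )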
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)

@@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_calendar_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured delete_calendar_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_calendar_event(
        event_title_or_id: str,
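A condensed sketch of the per-call session pattern this docstring describes, assuming only async_session_maker from app.db; create_example_tool and example_tool are illustrative names, not part of the codebase.

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import async_session_maker


def create_example_tool(db_session: AsyncSession | None = None):
    # The request-scoped session is ignored so the returned closure stays
    # valid when the compiled agent is cached across HTTP requests.
    del db_session  # reserved for registry compatibility

    @tool
    async def example_tool(query: str) -> dict:
        """Look something up using a short-lived database session."""
        async with async_session_maker() as session:
            # ... run queries with `session`; it closes when the block exits.
            return {"status": "success", "query": query}

    return example_tool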
@@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
        logger.info(
            f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleCalendarToolMetadataService(db_session)
                context = await metadata_service.get_deletion_context(
                    search_space_id, user_id, event_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Event not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch deletion context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Google Calendar account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_calendar",
                    }

                event = context["event"]
                event_id = event["event_id"]
                document_id = event.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not event_id:
                    return {
                        "status": "error",
                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                    }

                logger.info(
                    f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
                )
                result = request_approval(
                    action_type="google_calendar_event_deletion",
                    tool_name="delete_calendar_event",
                    params={
                        "event_id": event_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
                    }

                final_event_id = result.params.get("event_id", event_id)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = result.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this event.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _calendar_types = [
                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
                ]

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }

                actual_connector_id = connector.id

                logger.info(
                    f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
                )

                is_composio_calendar = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                )
                if is_composio_calendar:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }
                else:
                    config_data = dict(connector.config)

                    from app.config import config as app_config
                    from app.utils.oauth_security import TokenEncryption

                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and app_config.SECRET_KEY:
                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
                        for key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(key):
                                config_data[key] = token_encryption.decrypt_token(
                                    config_data[key]
                                )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                try:
                    if is_composio_calendar:
                        from app.services.composio_service import ComposioService

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLECALENDAR_DELETE_EVENT",
                            params={
                                "calendar_id": "primary",
                                "event_id": final_event_id,
                            },
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Calendar error"
                                )
                            )
                    else:
                        service = await asyncio.get_event_loop().run_in_executor(
                            None, lambda: build("calendar", "v3", credentials=creds)
                        )
                        await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda: (
                                service.events()
                                .delete(calendarId="primary", eventId=final_event_id)
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(f"Calendar event deleted: event_id={final_event_id}")

                delete_result: dict[str, Any] = {
                    "status": "success",
                    "event_id": final_event_id,
                    "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        delete_result["warning"] = (
                            f"Event deleted, but failed to remove from knowledge base: {e!s}"
                        )

                delete_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    delete_result["message"] = (
                        f"{delete_result.get('message', '')} (also removed from knowledge base)"
                    )

                return delete_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
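A hedged sketch of how the request_approval result is consumed in the deletion flow above, based only on the rejected and params attributes used there; resolve_deletion_params is an illustrative helper, not part of the codebase.

from app.agents.new_chat.tools.hitl import request_approval


def resolve_deletion_params(event_id: str, connector_id: int, context: dict) -> dict | None:
    result = request_approval(
        action_type="google_calendar_event_deletion",
        tool_name="delete_calendar_event",
        params={"event_id": event_id, "connector_id": connector_id},
        context=context,
    )
    if result.rejected:
        return None  # user declined; do not retry or suggest alternatives
    # Approved values override the proposed ones, falling back to the defaults.
    return {
        "event_id": result.params.get("event_id", event_id),
        "connector_id": result.params.get("connector_id", connector_id),
    }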
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker

logger = logging.getLogger(__name__)

@@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
]


def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
    if "T" in value:
        return value
    time = "23:59:59" if is_end else "00:00:00"
    return f"{value}T{time}Z"


def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
    events = []
    for ev in events_raw:
        start = ev.get("start", {})
        end = ev.get("end", {})
        attendees_raw = ev.get("attendees", [])
        events.append(
            {
                "event_id": ev.get("id"),
                "summary": ev.get("summary", "No Title"),
                "start": start.get("dateTime") or start.get("date", ""),
                "end": end.get("dateTime") or end.get("date", ""),
                "location": ev.get("location", ""),
                "description": ev.get("description", ""),
                "html_link": ev.get("htmlLink", ""),
                "attendees": [a.get("email", "") for a in attendees_raw[:10]],
                "status": ev.get("status", ""),
            }
        )
    return events


def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the search_calendar_events tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured search_calendar_events tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def search_calendar_events(
        start_date: str,
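A small usage sketch for _to_calendar_boundary as defined above: date-only inputs expand to full-day RFC3339 boundaries, while values that already carry a time component pass through unchanged.

# Assumes _to_calendar_boundary from the module above.
assert _to_calendar_boundary("2024-05-01", is_end=False) == "2024-05-01T00:00:00Z"
assert _to_calendar_boundary("2024-05-01", is_end=True) == "2024-05-01T23:59:59Z"
assert _to_calendar_boundary("2024-05-01T09:30:00Z", is_end=True) == "2024-05-01T09:30:00Z"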
@@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
            Dictionary with status and a list of events including
            event_id, summary, start, end, location, attendees.
        """
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Calendar tool not properly configured.",

@@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
        max_results = min(max_results, 50)

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                    }

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                ):
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }

                    from app.services.composio_service import ComposioService

                    events_raw, error = await ComposioService().get_calendar_events(
                        connected_account_id=cca_id,
                        entity_id=f"surfsense_{user_id}",
                        time_min=_to_calendar_boundary(start_date, is_end=False),
                        time_max=_to_calendar_boundary(end_date, is_end=True),
                        max_results=max_results,
                    )
                    if not events_raw and not error:
                        error = "No events found in the specified date range."
                else:
                    creds = _build_credentials(connector)

                    from app.connectors.google_calendar_connector import (
                        GoogleCalendarConnector,
                    )

                    cal = GoogleCalendarConnector(
                        credentials=creds,
                        session=db_session,
                        user_id=user_id,
                        connector_id=connector.id,
                    )

                    events_raw, error = await cal.get_all_primary_calendar_events(
                        start_date=start_date,
                        end_date=end_date,
                        max_results=max_results,
                    )

                if error:
                    if (
                        "re-authenticate" in error.lower()
                        or "authentication failed" in error.lower()
                    ):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "google_calendar",
                        }
                    if "no events found" in error.lower():
                        return {
                            "status": "success",
                            "events": [],
                            "total": 0,
                            "message": error,
                        }
                    return {"status": "error", "message": error}

                events = _format_calendar_events(events_raw)

                return {"status": "success", "events": events, "total": len(events)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt
@@ -9,6 +9,7 @@ from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)

@@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the update_calendar_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured update_calendar_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_calendar_event(
        event_title_or_id: str,

@@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
        """
        logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleCalendarToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, event_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Event not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch update context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                if context.get("auth_expired"):
                    logger.warning("Google Calendar account has expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_calendar",
                    }

                event = context["event"]
                event_id = event["event_id"]
                document_id = event.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not event_id:
                    return {
                        "status": "error",
                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                    }

                logger.info(
                    f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
                )
                result = request_approval(
                    action_type="google_calendar_event_update",
                    tool_name="update_calendar_event",
                    params={
                        "event_id": event_id,
                        "document_id": document_id,
                        "connector_id": connector_id_from_context,
                        "new_summary": new_summary,
                        "new_start_datetime": new_start_datetime,
                        "new_end_datetime": new_end_datetime,
                        "new_description": new_description,
                        "new_location": new_location,
                        "new_attendees": new_attendees,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
                    }

                final_event_id = result.params.get("event_id", event_id)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_new_summary = result.params.get("new_summary", new_summary)
                final_new_start_datetime = result.params.get(
                    "new_start_datetime", new_start_datetime
                )
                final_new_end_datetime = result.params.get(
                    "new_end_datetime", new_end_datetime
                )
                final_new_description = result.params.get(
                    "new_description", new_description
                )
                final_new_location = result.params.get("new_location", new_location)
                final_new_attendees = result.params.get("new_attendees", new_attendees)

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this event.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _calendar_types = [
                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
                ]

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }

                actual_connector_id = connector.id

                logger.info(
                    f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
                )

                is_composio_calendar = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                )
                if is_composio_calendar:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }
                else:
                    config_data = dict(connector.config)

                    from app.config import config as app_config
                    from app.utils.oauth_security import TokenEncryption

                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and app_config.SECRET_KEY:
                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
                        for key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(key):
                                config_data[key] = token_encryption.decrypt_token(
                                    config_data[key]
                                )

                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                update_body: dict[str, Any] = {}
                if final_new_summary is not None:
                    update_body["summary"] = final_new_summary
                if final_new_start_datetime is not None:
                    update_body["start"] = _build_time_body(
                        final_new_start_datetime, context
                    )
                if final_new_end_datetime is not None:
                    update_body["end"] = _build_time_body(
                        final_new_end_datetime, context
                    )
                if final_new_description is not None:
                    update_body["description"] = final_new_description
                if final_new_location is not None:
                    update_body["location"] = final_new_location
                if final_new_attendees is not None:
                    update_body["attendees"] = [
                        {"email": e.strip()} for e in final_new_attendees if e.strip()
                    ]

                if not update_body:
                    return {
                        "status": "error",
                        "message": "No changes specified. Please provide at least one field to update.",
                    }

                try:
                    if is_composio_calendar:
                        from app.services.composio_service import ComposioService

                        composio_params: dict[str, Any] = {
                            "calendar_id": "primary",
                            "event_id": final_event_id,
                        }
                        if final_new_summary is not None:
                            composio_params["summary"] = final_new_summary
                        if final_new_start_datetime is not None:
                            composio_params["start_time"] = final_new_start_datetime
                        if final_new_end_datetime is not None:
                            composio_params["end_time"] = final_new_end_datetime
                        if final_new_description is not None:
                            composio_params["description"] = final_new_description
                        if final_new_location is not None:
                            composio_params["location"] = final_new_location
                        if final_new_attendees is not None:
                            composio_params["attendees"] = [
                                e.strip() for e in final_new_attendees if e.strip()
                            ]
                        if not _is_date_only(
                            final_new_start_datetime or final_new_end_datetime or ""
                        ):
                            composio_params["timezone"] = context.get("timezone", "UTC")

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLECALENDAR_PATCH_EVENT",
                            params=composio_params,
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Calendar error"
                                )
                            )
                        updated = composio_result.get("data", {})
                        if isinstance(updated, dict):
                            updated = updated.get("data", updated)
                        if isinstance(updated, dict):
                            updated = updated.get("response_data", updated)
                    else:
                        service = await asyncio.get_event_loop().run_in_executor(
                            None, lambda: build("calendar", "v3", credentials=creds)
                        )
                        updated = await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda: (
                                service.events()
                                .patch(
                                    calendarId="primary",
                                    eventId=final_event_id,
                                    body=update_body,
                                )
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(f"Calendar event updated: event_id={final_event_id}")

                kb_message_suffix = ""
                if document_id is not None:
                    try:
                        from app.services.google_calendar import (
                            GoogleCalendarKBSyncService,
                        )

                        kb_service = GoogleCalendarKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=document_id,
                            event_id=final_event_id,
                            connector_id=actual_connector_id,
                            search_space_id=search_space_id,
                            user_id=user_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                        else:
                            kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
                    except Exception as kb_err:
                        logger.warning(f"KB sync after update failed: {kb_err}")
                        kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."

                return {
                    "status": "success",
                    "event_id": final_event_id,
                    "html_link": updated.get("htmlLink"),
                    "message": f"Successfully updated the calendar event.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
|
|
|
||||||
|
|
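
The nested unwrapping of the Composio response above (data, then data again, then response_data) recurs in every Composio branch of this change. A minimal restatement, with an illustrative helper name and a fabricated payload purely for demonstration; it is not part of the diff:

from typing import Any

def unwrap_composio_payload(result: dict[str, Any]) -> dict[str, Any]:
    """Peel the nested envelopes a Composio execute_tool call may return."""
    payload = result.get("data", {})
    if isinstance(payload, dict):
        payload = payload.get("data", payload)
    if isinstance(payload, dict):
        payload = payload.get("response_data", payload)
    return payload if isinstance(payload, dict) else {}

# Example: a doubly wrapped payload collapses to the inner event dict.
print(unwrap_composio_payload({"data": {"data": {"response_data": {"id": "evt_1"}}}}))
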
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)

@@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_google_drive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Google Drive connector
        user_id: User ID for fetching user-specific context

    Returns:
        Configured create_google_drive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_google_drive_file(
        name: str,

@@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
            f"create_google_drive_file called: name='{name}', type='{file_type}'"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",

@@ -78,195 +98,232 @@ def create_create_google_drive_file_tool(
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleDriveToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning(
                        "All Google Drive accounts have expired authentication"
                    )
                    return {
                        "status": "auth_error",
                        "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_drive",
                    }

                logger.info(
                    f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
                )
                result = request_approval(
                    action_type="google_drive_file_creation",
                    tool_name="create_google_drive_file",
                    params={
                        "name": name,
                        "file_type": file_type,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_id": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
                    }

                final_name = result.params.get("name", name)
                final_file_type = result.params.get("file_type", file_type)
                final_content = result.params.get("content", content)
                final_connector_id = result.params.get("connector_id")
                final_parent_folder_id = result.params.get("parent_folder_id")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                mime_type = _MIME_MAP.get(final_file_type)
                if not mime_type:
                    return {
                        "status": "error",
                        "message": f"Unsupported file type '{final_file_type}'.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _drive_types = [
                    SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                ]

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_drive_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Google Drive connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = connector.id
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type.in_(_drive_types),
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
                        }
                    actual_connector_id = connector.id

                logger.info(
                    f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
                )

                is_composio_drive = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
                )
                if is_composio_drive:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Drive connector.",
                        }
                client = GoogleDriveClient(
                    session=db_session,
                    connector_id=actual_connector_id,
                )
                try:
                    if is_composio_drive:
                        from app.services.composio_service import ComposioService

                        params: dict[str, Any] = {
                            "name": final_name,
                            "mimeType": mime_type,
                            "fields": "id,name,webViewLink,mimeType",
                        }
                        if final_parent_folder_id:
                            params["parents"] = [final_parent_folder_id]
                        if final_content:
                            params["description"] = final_content[:4096]

                        result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLEDRIVE_CREATE_FILE",
                            params=params,
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not result.get("success"):
                            raise RuntimeError(
                                result.get("error", "Unknown Composio Drive error")
                            )
                        created = result.get("data", {})
                        if isinstance(created, dict):
                            created = created.get("data", created)
                        if isinstance(created, dict):
                            created = created.get("response_data", created)
                        if not isinstance(created, dict):
                            created = {}
                    else:
                        created = await client.create_file(
                            name=final_name,
                            mime_type=mime_type,
                            parent_folder_id=final_parent_folder_id,
                            content=final_content,
                        )
                except HttpError as http_err:
                    if http_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
                        )
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(
                    f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
                )

                kb_message_suffix = ""
                try:
                    from app.services.google_drive import GoogleDriveKBSyncService

                    kb_service = GoogleDriveKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=created.get("id"),
                        file_name=created.get("name", final_name),
                        mime_type=mime_type,
                        web_view_link=created.get("webViewLink"),
                        content=final_content,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": created.get("id"),
                    "name": created.get("name"),
                    "web_view_link": created.get("webViewLink"),
                    "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
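
The docstring above explains why the factory discards the registry-supplied session and opens one per call. A stripped-down sketch of that shape follows; the tool name is illustrative and the import path for the @tool decorator is assumed to be the standard langchain_core location, which this diff does not show:

from langchain_core.tools import tool  # assumed import path for the @tool decorator

from app.db import async_session_maker

def create_example_tool(db_session=None, search_space_id: int | None = None, user_id: str | None = None):
    del db_session  # never capture a per-request session in the shared closure

    @tool
    async def example_tool(query: str) -> dict:
        """Illustrative body: every invocation opens its own short-lived session."""
        async with async_session_maker() as session:
            # ... all DB work for this call happens on `session` ...
            return {"status": "success", "query": query}

    return example_tool
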
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)

@@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_google_drive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Google Drive connector
        user_id: User ID for fetching user-specific context

    Returns:
        Configured delete_google_drive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_google_drive_file(
        file_name: str,

@@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool(
            f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleDriveToolMetadataService(db_session)
                context = await metadata_service.get_trash_context(
                    search_space_id, user_id, file_name
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"File not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch trash context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Google Drive account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_drive",
                    }

                file = context["file"]
                file_id = file["file_id"]
                document_id = file.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not file_id:
                    return {
                        "status": "error",
                        "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
                    }

                logger.info(
                    f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
                )
                result = request_approval(
                    action_type="google_drive_file_trash",
                    tool_name="delete_google_drive_file",
                    params={
                        "file_id": file_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
                    }

                final_file_id = result.params.get("file_id", file_id)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = result.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this file.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _drive_types = [
                    SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                ]

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Drive connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
                )

                is_composio_drive = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
                )
                if is_composio_drive:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Drive connector.",
                        }

                client = GoogleDriveClient(
                    session=db_session,
                    connector_id=connector.id,
                )
                try:
                    if is_composio_drive:
                        from app.services.composio_service import ComposioService

                        result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLEDRIVE_TRASH_FILE",
                            params={"file_id": final_file_id},
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not result.get("success"):
                            raise RuntimeError(
                                result.get("error", "Unknown Composio Drive error")
                            )
                    else:
                        await client.trash_file(file_id=final_file_id)
                except HttpError as http_err:
                    if http_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {connector.id}: {http_err}"
                        )
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            if not connector.config.get("auth_expired"):
                                connector.config = {
                                    **connector.config,
                                    "auth_expired": True,
                                }
                                flag_modified(connector, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                connector.id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": connector.id,
                            "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(
                    f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
                )

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": final_file_id,
                    "message": f"Successfully moved '{file['name']}' to trash.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File moved to trash, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
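
Each 403 handler above persists auth_expired the same way: because SQLAlchemy does not detect in-place mutation of a JSON config column, the handlers assign a fresh dict and call flag_modified before committing. A generalized sketch of that pattern, with an illustrative helper name and a generic model parameter standing in for SearchSourceConnector:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

async def mark_auth_expired(db_session: AsyncSession, connector_model, connector_id: int) -> None:
    """Flag a connector's config as auth_expired so the UI can prompt for re-auth."""
    res = await db_session.execute(
        select(connector_model).where(connector_model.id == connector_id)
    )
    conn = res.scalar_one_or_none()
    if conn and not conn.config.get("auth_expired"):
        conn.config = {**conn.config, "auth_expired": True}  # new dict, not in-place mutation
        flag_modified(conn, "config")
        await db_session.commit()
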
@@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    {
        "create_gmail_draft",
        "update_gmail_draft",
        "create_calendar_event",
        "create_notion_page",
        "create_confluence_page",
        "create_google_drive_file",
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService

logger = logging.getLogger(__name__)

@@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the create_jira_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits. Per-call sessions also
    keep the request's outer transaction free of long-running Jira API
    blocking.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Jira connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured create_jira_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_jira_issue(
        project_key: str,

@@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
            f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
        )

        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected Jira accounts need re-authentication.",
                        "connector_type": "jira",
                    }

                result = request_approval(
                    action_type="jira_issue_creation",
                    tool_name="create_jira_issue",
                    params={
                        "project_key": project_key,
                        "summary": summary,
                        "issue_type": issue_type,
                        "description": description,
                        "priority": priority,
                        "connector_id": connector_id,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_project_key = result.params.get("project_key", project_key)
                final_summary = result.params.get("summary", summary)
                final_issue_type = result.params.get("issue_type", issue_type)
                final_description = result.params.get("description", description)
                final_priority = result.params.get("priority", priority)
                final_connector_id = result.params.get("connector_id", connector_id)

                if not final_summary or not final_summary.strip():
                    return {
                        "status": "error",
                        "message": "Issue summary cannot be empty.",
                    }
                if not final_project_key:
                    return {"status": "error", "message": "A project must be selected."}

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                actual_connector_id = final_connector_id
                if actual_connector_id is None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.JIRA_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Jira connector found.",
                        }
                    actual_connector_id = connector.id
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == actual_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.JIRA_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Jira connector is invalid.",
                        }

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=actual_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    api_result = await asyncio.to_thread(
                        jira_client.create_issue,
                        project_key=final_project_key,
                        summary=final_summary,
                        issue_type=final_issue_type,
                        description=final_description,
                        priority=final_priority,
                    )
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        try:
                            _conn = connector
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                        except Exception:
                            pass
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                issue_key = api_result.get("key", "")
                issue_url = (
                    f"{jira_history._base_url}/browse/{issue_key}"
                    if jira_history._base_url and issue_key
                    else ""
                )

                kb_message_suffix = ""
                try:
                    from app.services.jira import JiraKBSyncService

                    kb_service = JiraKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        issue_id=issue_key,
                        issue_identifier=issue_key,
                        issue_title=final_summary,
                        description=final_description,
                        state="To Do",
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "issue_key": issue_key,
                    "issue_url": issue_url,
                    "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
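
The Jira client used above is synchronous, so the create call is pushed off the event loop with asyncio.to_thread, which is what the docstring means by keeping the request free of long-running blocking. A minimal, self-contained sketch with a stand-in client that only mimics the call shape; it is not the real JiraHistoryConnector client:

import asyncio

class _FakeJiraClient:
    """Stand-in for the real synchronous Jira client; illustrative only."""

    def create_issue(self, **fields) -> dict:
        return {"key": "PROJ-1", **fields}

async def main() -> None:
    client = _FakeJiraClient()
    # The blocking call runs in a worker thread; the event loop stays responsive.
    result = await asyncio.to_thread(client.create_issue, project_key="PROJ", summary="Example")
    print(result["key"])

asyncio.run(main())
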
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.connectors.jira_history import JiraHistoryConnector
|
from app.connectors.jira_history import JiraHistoryConnector
|
||||||
|
from app.db import async_session_maker
|
||||||
from app.services.jira import JiraToolMetadataService
|
from app.services.jira import JiraToolMetadataService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
):
|
):
|
||||||
|
"""Factory function to create the delete_jira_issue tool.
|
||||||
|
|
||||||
|
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||||
|
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||||
|
cache: the compiled graph (and therefore this closure) is reused
|
||||||
|
across HTTP requests, so capturing a per-request session here would
|
||||||
|
surface stale/closed sessions on cache hits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Reserved for registry compatibility. Per-call sessions
|
||||||
|
are opened via :data:`async_session_maker` inside the tool body.
|
||||||
|
search_space_id: Search space ID to find the Jira connector
|
||||||
|
user_id: User ID for fetching user-specific context
|
||||||
|
connector_id: Optional specific connector ID (if known)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured delete_jira_issue tool
|
||||||
|
"""
|
||||||
|
del db_session # per-call session — see docstring
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_jira_issue(
|
async def delete_jira_issue(
|
||||||
issue_title_or_key: str,
|
issue_title_or_key: str,
|
||||||
|
|
@@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
            f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
        )

        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_deletion_context(
                    search_space_id, user_id, issue_title_or_key
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "jira",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                issue_data = context["issue"]
                issue_key = issue_data["issue_id"]
                document_id = issue_data["document_id"]
                connector_id_from_context = context.get("account", {}).get("id")

                result = request_approval(
                    action_type="jira_issue_deletion",
                    tool_name="delete_jira_issue",
                    params={
                        "issue_key": issue_key,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_issue_key = result.params.get("issue_key", issue_key)
                final_connector_id = result.params.get("connector_id", connector_id_from_context)
                final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {"status": "error", "message": "No connector found for this issue."}

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {"status": "error", "message": "Selected Jira connector is invalid."}

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=final_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        try:
                            connector.config = {**connector.config, "auth_expired": True}
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            pass
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()

                message = f"Jira issue {final_issue_key} deleted successfully."
                if deleted_from_kb:
                    message += " Also removed from the knowledge base."

                return {
                    "status": "success",
                    "issue_key": final_issue_key,
                    "deleted_from_kb": deleted_from_kb,
                    "message": message,
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
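The deletion path above routes the destructive call through request_approval and then re-reads every parameter from the approval result, so a reviewer can edit them before execution. A schematic, framework-agnostic sketch of that contract; the real request_approval comes from app.agents.new_chat.tools.hitl and the dataclass here is only a stand-in:

from dataclasses import dataclass, field

@dataclass
class ApprovalResult:
    rejected: bool
    params: dict = field(default_factory=dict)

def run_with_approval(request_approval, issue_key: str, delete_from_kb: bool) -> dict:
    result = request_approval(
        action_type="jira_issue_deletion",
        tool_name="delete_jira_issue",
        params={"issue_key": issue_key, "delete_from_kb": delete_from_kb},
        context={},
    )
    if result.rejected:
        # Hard stop: the agent is told not to retry or propose alternatives.
        return {"status": "rejected"}
    # The approver may have edited the params, so the edited values win.
    final_key = result.params.get("issue_key", issue_key)
    final_kb = result.params.get("delete_from_kb", delete_from_kb)
    return {"status": "approved", "issue_key": final_key, "delete_from_kb": final_kb}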
@@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService

logger = logging.getLogger(__name__)

@@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the update_jira_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Jira connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured update_jira_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_jira_issue(
        issue_title_or_key: str,
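Inside the async tool bodies, the synchronous Jira client is always invoked through asyncio.to_thread so the blocking HTTP call does not stall the event loop. A small self-contained sketch of the same idea, with a stand-in for the SDK call:

import asyncio
import time

def blocking_update(issue_key: str, fields: dict) -> str:
    # Stand-in for a synchronous SDK call such as jira_client.update_issue(...).
    time.sleep(0.1)
    return f"{issue_key} updated with {sorted(fields)}"

async def main() -> None:
    # to_thread runs the blocking function in the default executor,
    # keeping the event loop free for other concurrent tool calls.
    message = await asyncio.to_thread(blocking_update, "PROJ-123", {"summary": "New title"})
    print(message)

asyncio.run(main())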
@@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
            f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
        )

        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, issue_title_or_key
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "jira",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                issue_data = context["issue"]
                issue_key = issue_data["issue_id"]
                document_id = issue_data.get("document_id")
                connector_id_from_context = context.get("account", {}).get("id")

                result = request_approval(
                    action_type="jira_issue_update",
                    tool_name="update_jira_issue",
                    params={
                        "issue_key": issue_key,
                        "document_id": document_id,
                        "new_summary": new_summary,
                        "new_description": new_description,
                        "new_priority": new_priority,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_issue_key = result.params.get("issue_key", issue_key)
                final_summary = result.params.get("new_summary", new_summary)
                final_description = result.params.get("new_description", new_description)
                final_priority = result.params.get("new_priority", new_priority)
                final_connector_id = result.params.get("connector_id", connector_id_from_context)
                final_document_id = result.params.get("document_id", document_id)

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {"status": "error", "message": "No connector found for this issue."}

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {"status": "error", "message": "Selected Jira connector is invalid."}

                fields: dict[str, Any] = {}
                if final_summary:
                    fields["summary"] = final_summary
                if final_description is not None:
                    fields["description"] = {
                        "type": "doc",
                        "version": 1,
                        "content": [
                            {
                                "type": "paragraph",
                                "content": [{"type": "text", "text": final_description}],
                            }
                        ],
                    }
                if final_priority:
                    fields["priority"] = {"name": final_priority}

                if not fields:
                    return {"status": "error", "message": "No changes specified."}

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=final_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    await asyncio.to_thread(
                        jira_client.update_issue, final_issue_key, fields
                    )
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        try:
                            connector.config = {**connector.config, "auth_expired": True}
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            pass
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                issue_url = (
                    f"{jira_history._base_url}/browse/{final_issue_key}"
                    if jira_history._base_url and final_issue_key
                    else ""
                )

                kb_message_suffix = ""
                if final_document_id:
                    try:
                        from app.services.jira import JiraKBSyncService

                        kb_service = JiraKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=final_document_id,
                            issue_id=final_issue_key,
                            user_id=user_id,
                            search_space_id=search_space_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                        else:
                            kb_message_suffix = (
                                " The knowledge base will be updated in the next sync."
                            )
                    except Exception as kb_err:
                        logger.warning(f"KB sync after update failed: {kb_err}")
                        kb_message_suffix = (
                            " The knowledge base will be updated in the next sync."
                        )

                return {
                    "status": "success",
                    "issue_key": final_issue_key,
                    "issue_url": issue_url,
                    "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
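The update path wraps a plain-text description into Atlassian Document Format, a "doc" node containing paragraph and text nodes, before sending it to Jira. A small helper showing that payload shape for the single-paragraph case handled above:

def to_adf_paragraph(description: str) -> dict:
    # Minimal Atlassian Document Format payload: one paragraph of plain text.
    return {
        "type": "doc",
        "version": 1,
        "content": [
            {
                "type": "paragraph",
                "content": [{"type": "text", "text": description}],
            }
        ],
    }

fields = {"summary": "New title", "description": to_adf_paragraph("Updated details")}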
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService

logger = logging.getLogger(__name__)

@@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the create_linear_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Linear connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

@@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
    Returns:
        Configured create_linear_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_linear_issue(

@@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
        """
        logger.info(f"create_linear_issue called: title='{title}'")

        if search_space_id is None or user_id is None:
            logger.error(
                "Linear tool not properly configured - missing required parameters"
            )
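The factory keeps db_session in its signature for registry compatibility but discards it immediately with del, which makes the deliberately unused capture explicit. A small sketch of that idiom in isolation, with hypothetical names:

from collections.abc import Callable

def create_example_tool(db_session=None, search_space_id: int | None = None) -> Callable:
    # Accept the legacy argument so existing call sites keep working,
    # but drop it so the closure cannot accidentally hold a stale session.
    del db_session

    async def example_tool() -> dict:
        if search_space_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        return {"status": "success", "search_space_id": search_space_id}

    return example_tool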
@@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = LinearToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                workspaces = context.get("workspaces", [])
                if workspaces and all(w.get("auth_expired") for w in workspaces):
                    logger.warning("All Linear accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "linear",
                    }

                logger.info(f"Requesting approval for creating Linear issue: '{title}'")
                result = request_approval(
                    action_type="linear_issue_creation",
                    tool_name="create_linear_issue",
                    params={
                        "title": title,
                        "description": description,
                        "team_id": None,
                        "state_id": None,
                        "assignee_id": None,
                        "priority": None,
                        "label_ids": [],
                        "connector_id": connector_id,
                    },
                    context=context,
                )

                if result.rejected:
                    logger.info("Linear issue creation rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_title = result.params.get("title", title)
                final_description = result.params.get("description", description)
                final_team_id = result.params.get("team_id")
                final_state_id = result.params.get("state_id")
                final_assignee_id = result.params.get("assignee_id")
                final_priority = result.params.get("priority")
                final_label_ids = result.params.get("label_ids") or []
                final_connector_id = result.params.get("connector_id", connector_id)

                if not final_title or not final_title.strip():
                    logger.error("Title is empty or contains only whitespace")
                    return {
                        "status": "error",
                        "message": "Issue title cannot be empty.",
                    }
                if not final_team_id:
                    return {
                        "status": "error",
                        "message": "A team must be selected to create an issue.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                actual_connector_id = final_connector_id
                if actual_connector_id is None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.LINEAR_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Linear connector found. Please connect Linear in your workspace settings.",
                        }
                    actual_connector_id = connector.id
                    logger.info(f"Found Linear connector: id={actual_connector_id}")
                else:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == actual_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.LINEAR_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Linear connector is invalid or has been disconnected.",
                        }
                    logger.info(f"Validated Linear connector: id={actual_connector_id}")

                logger.info(
                    f"Creating Linear issue with final params: title='{final_title}'"
                )
                linear_client = LinearConnector(
                    session=db_session, connector_id=actual_connector_id
                )
                result = await linear_client.create_issue(
                    team_id=final_team_id,
                    title=final_title,
                    description=final_description,
                    state_id=final_state_id,
                    assignee_id=final_assignee_id,
                    priority=final_priority,
                    label_ids=final_label_ids if final_label_ids else None,
                )

                if result.get("status") == "error":
                    logger.error(
                        f"Failed to create Linear issue: {result.get('message')}"
                    )
                    return {"status": "error", "message": result.get("message")}

                logger.info(
                    f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
                )

                kb_message_suffix = ""
                try:
                    from app.services.linear import LinearKBSyncService

                    kb_service = LinearKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        issue_id=result.get("id"),
                        issue_identifier=result.get("identifier", ""),
                        issue_title=result.get("title", final_title),
                        issue_url=result.get("url"),
                        description=final_description,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "issue_id": result.get("id"),
                    "identifier": result.get("identifier"),
                    "url": result.get("url"),
                    "message": (result.get("message", "") + kb_message_suffix),
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
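Before acting, each tool re-validates the connector id against the caller's search space, user, and connector type, so an approved-but-edited connector_id cannot point at someone else's connector. A hedged sketch of that scoped lookup with SQLAlchemy; the model and enum are passed in here because the sketch stands outside the application package:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def load_scoped_connector(
    session: AsyncSession,
    model,
    connector_type,
    connector_id: int,
    search_space_id: int,
    user_id: str,
):
    # All four filters must match; anything else is treated as an invalid connector.
    result = await session.execute(
        select(model).filter(
            model.id == connector_id,
            model.search_space_id == search_space_id,
            model.user_id == user_id,
            model.connector_type == connector_type,
        )
    )
    return result.scalars().first()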
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearToolMetadataService

logger = logging.getLogger(__name__)

@@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the delete_linear_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Linear connector
        user_id: User ID for finding the correct Linear connector
        connector_id: Optional specific connector ID (if known)

@@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
    Returns:
        Configured delete_linear_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_linear_issue(

@@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
            f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            logger.error(
                "Linear tool not properly configured - missing required parameters"
            )
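Each tool's outer except block imports GraphInterrupt before deciding how to handle the error; interrupts raised by the human-in-the-loop layer must propagate to LangGraph rather than being flattened into an error payload. The handler bodies are elided in this log, so the split below is an assumption about intent, not the project's exact code:

from langgraph.errors import GraphInterrupt

async def guarded_tool_call(do_work) -> dict:
    try:
        return await do_work()
    except GraphInterrupt:
        # Approval interrupts are control flow, not failures: let the graph pause.
        raise
    except Exception as e:
        return {"status": "error", "message": str(e)}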
@@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = LinearToolMetadataService(db_session)
                context = await metadata_service.get_delete_context(
                    search_space_id, user_id, issue_ref
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        logger.warning(f"Auth expired for delete context: {error_msg}")
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "linear",
                        }
                    if "not found" in error_msg.lower():
                        logger.warning(f"Issue not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    else:
                        logger.error(f"Failed to fetch delete context: {error_msg}")
                        return {"status": "error", "message": error_msg}

                issue_id = context["issue"]["id"]
                issue_identifier = context["issue"].get("identifier", "")
                document_id = context["issue"]["document_id"]
                connector_id_from_context = context.get("workspace", {}).get("id")

                logger.info(
                    f"Requesting approval for deleting Linear issue: '{issue_ref}' "
                    f"(id={issue_id}, delete_from_kb={delete_from_kb})"
                )
                result = request_approval(
                    action_type="linear_issue_deletion",
                    tool_name="delete_linear_issue",
                    params={
                        "issue_id": issue_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    logger.info("Linear issue deletion rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_issue_id = result.params.get("issue_id", issue_id)
                final_connector_id = result.params.get("connector_id", connector_id_from_context)
                final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

                logger.info(
                    f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
                    f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if final_connector_id:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.LINEAR_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                    if not connector:
                        logger.error(
                            f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                        )
                        return {
                            "status": "error",
                            "message": "Selected Linear connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = connector.id
                    logger.info(f"Validated Linear connector: id={actual_connector_id}")
                else:
                    logger.error("No connector found for this issue")
                    return {
                        "status": "error",
                        "message": "No connector found for this issue.",
                    }

                linear_client = LinearConnector(
                    session=db_session, connector_id=actual_connector_id
                )

                result = await linear_client.archive_issue(issue_id=final_issue_id)

                logger.info(
                    f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
                )

                deleted_from_kb = False
                if (
                    result.get("status") == "success"
                    and final_delete_from_kb
                    and document_id
                ):
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        result["warning"] = (
                            f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
                        )

                if result.get("status") == "success":
                    result["deleted_from_kb"] = deleted_from_kb
                    if issue_identifier:
                        result["message"] = (
                            f"Issue {issue_identifier} archived successfully."
                        )
                    if deleted_from_kb:
                        result["message"] = (
                            f"{result.get('message', '')} Also removed from the knowledge base."
                        )

                return result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
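On a 403 from the provider API, the tools mark the connector's config with auth_expired=True and call flag_modified so SQLAlchemy persists the mutated JSON column, then report an insufficient_permissions status. A condensed sketch of that persistence step:

from sqlalchemy.orm.attributes import flag_modified

async def mark_auth_expired(session, connector) -> None:
    # Re-assigning the dict plus flag_modified ensures the JSON column is
    # detected as dirty even though only a nested key changed.
    connector.config = {**connector.config, "auth_expired": True}
    flag_modified(connector, "config")
    await session.commit()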
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.db import async_session_maker
from app.services.linear import LinearKBSyncService, LinearToolMetadataService

logger = logging.getLogger(__name__)

@@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the update_linear_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Linear connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

@@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
    Returns:
        Configured update_linear_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_linear_issue(

@@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
        """
        logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")

        if search_space_id is None or user_id is None:
            logger.error(
                "Linear tool not properly configured - missing required parameters"
            )
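When the caller asks to remove an issue from the knowledge base as well, the tools delete the backing Document row and roll back on failure, so a KB error never leaves the session dirty for the rest of the call. A minimal sketch of that cleanup step; document_model stands for the project's Document model:

from sqlalchemy import select

async def delete_kb_document(session, document_model, document_id: int) -> bool:
    try:
        result = await session.execute(
            select(document_model).filter(document_model.id == document_id)
        )
        document = result.scalars().first()
        if document is None:
            return False
        await session.delete(document)
        await session.commit()
        return True
    except Exception:
        # Keep the session usable for whatever the tool still has to do.
        await session.rollback()
        return False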
@@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = LinearToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, issue_ref
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        logger.warning(f"Auth expired for update context: {error_msg}")
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "linear",
                        }
                    if "not found" in error_msg.lower():
                        logger.warning(f"Issue not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    else:
                        logger.error(f"Failed to fetch update context: {error_msg}")
                        return {"status": "error", "message": error_msg}

                issue_id = context["issue"]["id"]
                document_id = context["issue"]["document_id"]
                connector_id_from_context = context.get("workspace", {}).get("id")

                team = context.get("team", {})
                new_state_id = _resolve_state(team, new_state_name)
                new_assignee_id = _resolve_assignee(team, new_assignee_email)
                new_label_ids = _resolve_labels(team, new_label_names)

                logger.info(
                    f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
                )
                result = request_approval(
                    action_type="linear_issue_update",
                    tool_name="update_linear_issue",
                    params={
                        "issue_id": issue_id,
                        "document_id": document_id,
                        "new_title": new_title,
                        "new_description": new_description,
                        "new_state_id": new_state_id,
                        "new_assignee_id": new_assignee_id,
                        "new_priority": new_priority,
                        "new_label_ids": new_label_ids,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if result.rejected:
                    logger.info("Linear issue update rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_issue_id = result.params.get("issue_id", issue_id)
                final_document_id = result.params.get("document_id", document_id)
                final_new_title = result.params.get("new_title", new_title)
                final_new_description = result.params.get("new_description", new_description)
                final_new_state_id = result.params.get("new_state_id", new_state_id)
                final_new_assignee_id = result.params.get("new_assignee_id", new_assignee_id)
                final_new_priority = result.params.get("new_priority", new_priority)
                final_new_label_ids: list[str] | None = result.params.get("new_label_ids", new_label_ids)
                final_connector_id = result.params.get("connector_id", connector_id_from_context)

                if not final_connector_id:
                    logger.error("No connector found for this issue")
                    return {
                        "status": "error",
                        "message": "No connector found for this issue.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.LINEAR_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    logger.error(
                        f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                    )
                    return {
                        "status": "error",
                        "message": "Selected Linear connector is invalid or has been disconnected.",
|
|
||||||
}
|
|
||||||
logger.info(f"Validated Linear connector: id={final_connector_id}")
|
|
||||||
|
|
||||||
logger.info(
|
if "error" in context:
|
||||||
f"Updating Linear issue with final params: issue_id={final_issue_id}"
|
error_msg = context["error"]
|
||||||
)
|
if context.get("auth_expired"):
|
||||||
linear_client = LinearConnector(
|
logger.warning(f"Auth expired for update context: {error_msg}")
|
||||||
session=db_session, connector_id=final_connector_id
|
return {
|
||||||
)
|
"status": "auth_error",
|
||||||
updated_issue = await linear_client.update_issue(
|
"message": error_msg,
|
||||||
issue_id=final_issue_id,
|
"connector_id": context.get("connector_id"),
|
||||||
title=final_new_title,
|
"connector_type": "linear",
|
||||||
description=final_new_description,
|
}
|
||||||
state_id=final_new_state_id,
|
if "not found" in error_msg.lower():
|
||||||
assignee_id=final_new_assignee_id,
|
logger.warning(f"Issue not found: {error_msg}")
|
||||||
priority=final_new_priority,
|
return {"status": "not_found", "message": error_msg}
|
||||||
label_ids=final_new_label_ids,
|
else:
|
||||||
)
|
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||||
|
return {"status": "error", "message": error_msg}
|
||||||
|
|
||||||
if updated_issue.get("status") == "error":
|
issue_id = context["issue"]["id"]
|
||||||
logger.error(
|
document_id = context["issue"]["document_id"]
|
||||||
f"Failed to update Linear issue: {updated_issue.get('message')}"
|
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": updated_issue.get("message"),
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(
|
team = context.get("team", {})
|
||||||
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
|
new_state_id = _resolve_state(team, new_state_name)
|
||||||
)
|
new_assignee_id = _resolve_assignee(team, new_assignee_email)
|
||||||
|
new_label_ids = _resolve_labels(team, new_label_names)
|
||||||
|
|
||||||
if final_document_id is not None:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updating knowledge base for document {final_document_id}..."
|
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
|
||||||
)
|
)
|
||||||
kb_service = LinearKBSyncService(db_session)
|
result = request_approval(
|
||||||
kb_result = await kb_service.sync_after_update(
|
action_type="linear_issue_update",
|
||||||
document_id=final_document_id,
|
tool_name="update_linear_issue",
|
||||||
issue_id=final_issue_id,
|
params={
|
||||||
user_id=user_id,
|
"issue_id": issue_id,
|
||||||
search_space_id=search_space_id,
|
"document_id": document_id,
|
||||||
|
"new_title": new_title,
|
||||||
|
"new_description": new_description,
|
||||||
|
"new_state_id": new_state_id,
|
||||||
|
"new_assignee_id": new_assignee_id,
|
||||||
|
"new_priority": new_priority,
|
||||||
|
"new_label_ids": new_label_ids,
|
||||||
|
"connector_id": connector_id_from_context,
|
||||||
|
},
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
if kb_result["status"] == "success":
|
|
||||||
logger.info(
|
|
||||||
f"Knowledge base successfully updated for issue {final_issue_id}"
|
|
||||||
)
|
|
||||||
kb_message = " Your knowledge base has also been updated."
|
|
||||||
elif kb_result["status"] == "not_indexed":
|
|
||||||
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
|
|
||||||
)
|
|
||||||
kb_message = " Your knowledge base will be updated in the next scheduled sync."
|
|
||||||
else:
|
|
||||||
kb_message = ""
|
|
||||||
|
|
||||||
identifier = updated_issue.get("identifier")
|
if result.rejected:
|
||||||
default_msg = f"Issue {identifier} updated successfully."
|
logger.info("Linear issue update rejected by user")
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "rejected",
|
||||||
"identifier": identifier,
|
"message": "User declined. Do not retry or suggest alternatives.",
|
||||||
"url": updated_issue.get("url"),
|
}
|
||||||
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
|
|
||||||
}
|
final_issue_id = result.params.get("issue_id", issue_id)
|
||||||
|
final_document_id = result.params.get("document_id", document_id)
|
||||||
|
final_new_title = result.params.get("new_title", new_title)
|
||||||
|
final_new_description = result.params.get(
|
||||||
|
"new_description", new_description
|
||||||
|
)
|
||||||
|
final_new_state_id = result.params.get("new_state_id", new_state_id)
|
||||||
|
final_new_assignee_id = result.params.get(
|
||||||
|
"new_assignee_id", new_assignee_id
|
||||||
|
)
|
||||||
|
final_new_priority = result.params.get("new_priority", new_priority)
|
||||||
|
final_new_label_ids: list[str] | None = result.params.get(
|
||||||
|
"new_label_ids", new_label_ids
|
||||||
|
)
|
||||||
|
final_connector_id = result.params.get(
|
||||||
|
"connector_id", connector_id_from_context
|
||||||
|
)
|
||||||
|
|
||||||
|
if not final_connector_id:
|
||||||
|
logger.error("No connector found for this issue")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "No connector found for this issue.",
|
||||||
|
}
|
||||||
|
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == final_connector_id,
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
logger.error(
|
||||||
|
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||||
|
}
|
||||||
|
logger.info(f"Validated Linear connector: id={final_connector_id}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updating Linear issue with final params: issue_id={final_issue_id}"
|
||||||
|
)
|
||||||
|
linear_client = LinearConnector(
|
||||||
|
session=db_session, connector_id=final_connector_id
|
||||||
|
)
|
||||||
|
updated_issue = await linear_client.update_issue(
|
||||||
|
issue_id=final_issue_id,
|
||||||
|
title=final_new_title,
|
||||||
|
description=final_new_description,
|
||||||
|
state_id=final_new_state_id,
|
||||||
|
assignee_id=final_new_assignee_id,
|
||||||
|
priority=final_new_priority,
|
||||||
|
label_ids=final_new_label_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if updated_issue.get("status") == "error":
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update Linear issue: {updated_issue.get('message')}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": updated_issue.get("message"),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_document_id is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Updating knowledge base for document {final_document_id}..."
|
||||||
|
)
|
||||||
|
kb_service = LinearKBSyncService(db_session)
|
||||||
|
kb_result = await kb_service.sync_after_update(
|
||||||
|
document_id=final_document_id,
|
||||||
|
issue_id=final_issue_id,
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
)
|
||||||
|
if kb_result["status"] == "success":
|
||||||
|
logger.info(
|
||||||
|
f"Knowledge base successfully updated for issue {final_issue_id}"
|
||||||
|
)
|
||||||
|
kb_message = " Your knowledge base has also been updated."
|
||||||
|
elif kb_result["status"] == "not_indexed":
|
||||||
|
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
|
||||||
|
)
|
||||||
|
kb_message = " Your knowledge base will be updated in the next scheduled sync."
|
||||||
|
else:
|
||||||
|
kb_message = ""
|
||||||
|
|
||||||
|
identifier = updated_issue.get("identifier")
|
||||||
|
default_msg = f"Issue {identifier} updated successfully."
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"identifier": identifier,
|
||||||
|
"url": updated_issue.get("url"),
|
||||||
|
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
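
A minimal sketch of the pattern this hunk applies, assuming only that app.db exposes an async_sessionmaker-style factory named async_session_maker; the factory and tool names below are illustrative, not part of the diff:

import logging

from langchain_core.tools import tool
from sqlalchemy import text

from app.db import async_session_maker

logger = logging.getLogger(__name__)


def create_example_tool(db_session=None, search_space_id: int | None = None, user_id: str | None = None):
    # Keep db_session in the signature so existing registry call sites do not
    # break, but never capture it: the compiled agent graph (and therefore this
    # closure) outlives the HTTP request that created the session.
    del db_session

    @tool
    async def example_tool(query: str) -> dict:
        """Illustrative tool body: open a fresh session per invocation."""
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Tool not properly configured."}
        async with async_session_maker() as session:
            # All DB access for this call happens on this short-lived session,
            # so a cache hit on the compiled graph never sees a closed session.
            await session.execute(text("SELECT 1"))
            return {"status": "success", "query": query}

    return example_tool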
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker

 from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

@@ -17,6 +18,23 @@ def create_create_luma_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the create_luma_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured create_luma_event tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def create_luma_event(
         name: str,

@@ -40,83 +58,86 @@ def create_create_luma_event_tool(
         IMPORTANT:
         - If status is "rejected", the user explicitly declined. Do NOT retry.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Luma tool not properly configured."}

         try:
-            connector = await get_luma_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Luma connector found."}
+            async with async_session_maker() as db_session:
+                connector = await get_luma_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Luma connector found."}
+
+                result = request_approval(
+                    action_type="luma_create_event",
+                    tool_name="create_luma_event",
+                    params={
+                        "name": name,
+                        "start_at": start_at,
+                        "end_at": end_at,
+                        "description": description,
+                        "timezone": timezone,
+                    },
+                    context={"connector_id": connector.id},
+                )
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. Event was not created.",
+                    }
+
+                final_name = result.params.get("name", name)
+                final_start = result.params.get("start_at", start_at)
+                final_end = result.params.get("end_at", end_at)
+                final_desc = result.params.get("description", description)
+                final_tz = result.params.get("timezone", timezone)
+
+                api_key = get_api_key(connector)
+                headers = luma_headers(api_key)
+
+                body: dict[str, Any] = {
+                    "name": final_name,
+                    "start_at": final_start,
+                    "end_at": final_end,
+                    "timezone": final_tz,
+                }
+                if final_desc:
+                    body["description_md"] = final_desc
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    resp = await client.post(
+                        f"{LUMA_API}/event/create",
+                        headers=headers,
+                        json=body,
+                    )
+
+                if resp.status_code == 401:
+                    return {
+                        "status": "auth_error",
+                        "message": "Luma API key is invalid.",
+                        "connector_type": "luma",
+                    }
+                if resp.status_code == 403:
+                    return {
+                        "status": "error",
+                        "message": "Luma Plus subscription required to create events via API.",
+                    }
+                if resp.status_code not in (200, 201):
+                    return {
+                        "status": "error",
+                        "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
+                    }
+
+                data = resp.json()
+                event_id = data.get("api_id") or data.get("event", {}).get("api_id")
+
+                return {
+                    "status": "success",
+                    "event_id": event_id,
+                    "message": f"Event '{final_name}' created on Luma.",
+                }

         except Exception as e:
             from langgraph.errors import GraphInterrupt
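
For context, async_session_maker itself is not defined anywhere in this diff. A typical definition, shown here only as an assumption about app/db.py so the per-call usage above is concrete, looks like this (the engine URL is a placeholder):

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# The real URL comes from application settings; this value is illustrative.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/surfsense")

# expire_on_commit=False keeps ORM objects usable after the session commits,
# which matters when tool results are built from rows after the context exits.
async_session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)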
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession

+from app.db import async_session_maker
+
 from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

 logger = logging.getLogger(__name__)

@@ -15,6 +17,23 @@ def create_list_luma_events_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the list_luma_events tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured list_luma_events tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def list_luma_events(
         max_results: int = 25,

@@ -28,77 +47,80 @@ def create_list_luma_events_tool(
             Dictionary with status and a list of events including
             event_id, name, start_at, end_at, location, url.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Luma tool not properly configured."}

         max_results = min(max_results, 50)

         try:
-            connector = await get_luma_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Luma connector found."}
+            async with async_session_maker() as db_session:
+                connector = await get_luma_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Luma connector found."}
+
+                api_key = get_api_key(connector)
+                headers = luma_headers(api_key)
+
+                all_entries: list[dict] = []
+                cursor = None
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    while len(all_entries) < max_results:
+                        params: dict[str, Any] = {
+                            "limit": min(100, max_results - len(all_entries))
+                        }
+                        if cursor:
+                            params["cursor"] = cursor
+
+                        resp = await client.get(
+                            f"{LUMA_API}/calendar/list-events",
+                            headers=headers,
+                            params=params,
+                        )
+
+                        if resp.status_code == 401:
+                            return {
+                                "status": "auth_error",
+                                "message": "Luma API key is invalid.",
+                                "connector_type": "luma",
+                            }
+                        if resp.status_code != 200:
+                            return {
+                                "status": "error",
+                                "message": f"Luma API error: {resp.status_code}",
+                            }
+
+                        data = resp.json()
+                        entries = data.get("entries", [])
+                        if not entries:
+                            break
+                        all_entries.extend(entries)
+
+                        next_cursor = data.get("next_cursor")
+                        if not next_cursor:
+                            break
+                        cursor = next_cursor
+
+                events = []
+                for entry in all_entries[:max_results]:
+                    ev = entry.get("event", {})
+                    geo = ev.get("geo_info", {})
+                    events.append(
+                        {
+                            "event_id": entry.get("api_id"),
+                            "name": ev.get("name", "Untitled"),
+                            "start_at": ev.get("start_at", ""),
+                            "end_at": ev.get("end_at", ""),
+                            "timezone": ev.get("timezone", ""),
+                            "location": geo.get("name", ""),
+                            "url": ev.get("url", ""),
+                            "visibility": ev.get("visibility", ""),
+                        }
+                    )
+
+                return {"status": "success", "events": events, "total": len(events)}

         except Exception as e:
             from langgraph.errors import GraphInterrupt
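
The listing tool drains Luma's cursor-paginated calendar endpoint until it either fills max_results or runs out of pages. The same loop, distilled into a standalone helper as a sketch (the entries/next_cursor response shape matches the Luma usage in this diff; treat it as an assumption for any other API):

from typing import Any

import httpx


async def fetch_all_pages(url: str, headers: dict[str, str], limit: int) -> list[dict[str, Any]]:
    """Drain a cursor-paginated endpoint, mirroring the loop above."""
    entries: list[dict[str, Any]] = []
    cursor: str | None = None
    async with httpx.AsyncClient(timeout=20.0) as client:
        while len(entries) < limit:
            params: dict[str, Any] = {"limit": min(100, limit - len(entries))}
            if cursor:
                params["cursor"] = cursor
            resp = await client.get(url, headers=headers, params=params)
            resp.raise_for_status()
            data = resp.json()
            page = data.get("entries", [])
            if not page:
                break  # empty page: nothing left to fetch
            entries.extend(page)
            cursor = data.get("next_cursor")
            if not cursor:
                break  # server signals the final page by omitting the cursor
    return entries[:limit]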
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession

+from app.db import async_session_maker
+
 from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers

 logger = logging.getLogger(__name__)

@@ -15,6 +17,23 @@ def create_read_luma_event_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the read_luma_event tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured read_luma_event tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def read_luma_event(event_id: str) -> dict[str, Any]:
         """Read detailed information about a specific Luma event.

@@ -26,60 +45,63 @@ def create_read_luma_event_tool(
             Dictionary with status and full event details including
             description, attendees count, meeting URL.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Luma tool not properly configured."}

         try:
-            connector = await get_luma_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Luma connector found."}
+            async with async_session_maker() as db_session:
+                connector = await get_luma_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Luma connector found."}
+
+                api_key = get_api_key(connector)
+                headers = luma_headers(api_key)
+
+                async with httpx.AsyncClient(timeout=15.0) as client:
+                    resp = await client.get(
+                        f"{LUMA_API}/events/{event_id}",
+                        headers=headers,
+                    )
+
+                if resp.status_code == 401:
+                    return {
+                        "status": "auth_error",
+                        "message": "Luma API key is invalid.",
+                        "connector_type": "luma",
+                    }
+                if resp.status_code == 404:
+                    return {
+                        "status": "not_found",
+                        "message": f"Event '{event_id}' not found.",
+                    }
+                if resp.status_code != 200:
+                    return {
+                        "status": "error",
+                        "message": f"Luma API error: {resp.status_code}",
+                    }
+
+                data = resp.json()
+                ev = data.get("event", data)
+                geo = ev.get("geo_info", {})
+
+                event_detail = {
+                    "event_id": event_id,
+                    "name": ev.get("name", ""),
+                    "description": ev.get("description", ""),
+                    "start_at": ev.get("start_at", ""),
+                    "end_at": ev.get("end_at", ""),
+                    "timezone": ev.get("timezone", ""),
+                    "location_name": geo.get("name", ""),
+                    "address": geo.get("address", ""),
+                    "url": ev.get("url", ""),
+                    "meeting_url": ev.get("meeting_url", ""),
+                    "visibility": ev.get("visibility", ""),
+                    "cover_url": ev.get("cover_url", ""),
+                }
+
+                return {"status": "success", "event": event_detail}

         except Exception as e:
             from langgraph.errors import GraphInterrupt
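
All three Luma tools map HTTP status codes onto the same structured result dicts (auth_error, not_found, error). A small helper capturing that convention, hypothetical and not part of the diff, could look like:

from typing import Any

import httpx


def luma_error_result(resp: httpx.Response, subject: str) -> dict[str, Any] | None:
    """Translate a non-success Luma response into the tools' result convention.

    Returns None when the response is usable, so callers simply continue.
    """
    if resp.status_code == 401:
        return {
            "status": "auth_error",
            "message": "Luma API key is invalid.",
            "connector_type": "luma",
        }
    if resp.status_code == 404:
        return {"status": "not_found", "message": f"{subject} not found."}
    if resp.status_code != 200:
        return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
    return None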
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from app.agents.new_chat.tools.hitl import request_approval
 from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
 from app.services.notion import NotionToolMetadataService

 logger = logging.getLogger(__name__)

@@ -20,8 +21,17 @@ def create_create_notion_page_tool(
     """
     Factory function to create the create_notion_page tool.

+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker`. This is critical for the compiled-agent
+    cache: the compiled graph (and therefore this closure) is reused
+    across HTTP requests, so capturing a per-request session here would
+    surface stale/closed sessions on cache hits. Per-call sessions also
+    keep the request's outer transaction free of long-running Notion API
+    blocking.
+
     Args:
-        db_session: Database session for accessing Notion connector
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
         search_space_id: Search space ID to find the Notion connector
         user_id: User ID for fetching user-specific context
         connector_id: Optional specific connector ID (if known)

@@ -29,6 +39,7 @@ def create_create_notion_page_tool(
     Returns:
         Configured create_notion_page tool
     """
+    del db_session  # per-call session — see docstring

     @tool
     async def create_notion_page(

@@ -67,7 +78,7 @@ def create_create_notion_page_tool(
         """
         logger.info(f"create_notion_page called: title='{title}'")

-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             logger.error(
                 "Notion tool not properly configured - missing required parameters"
             )

@@ -77,154 +88,157 @@ def create_create_notion_page_tool(
             }

         try:
-            metadata_service = NotionToolMetadataService(db_session)
-            context = await metadata_service.get_creation_context(
-                search_space_id, user_id
-            )
+            async with async_session_maker() as db_session:
+                metadata_service = NotionToolMetadataService(db_session)
+                context = await metadata_service.get_creation_context(
+                    search_space_id, user_id
+                )
+
+                if "error" in context:
+                    logger.error(
+                        f"Failed to fetch creation context: {context['error']}"
+                    )
+                    return {
+                        "status": "error",
+                        "message": context["error"],
+                    }
+
+                accounts = context.get("accounts", [])
+                if accounts and all(a.get("auth_expired") for a in accounts):
+                    logger.warning("All Notion accounts have expired authentication")
+                    return {
+                        "status": "auth_error",
+                        "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
+                        "connector_type": "notion",
+                    }
+
+                logger.info(f"Requesting approval for creating Notion page: '{title}'")
+                result = request_approval(
+                    action_type="notion_page_creation",
+                    tool_name="create_notion_page",
+                    params={
+                        "title": title,
+                        "content": content,
+                        "parent_page_id": None,
+                        "connector_id": connector_id,
+                    },
+                    context=context,
+                )
+
+                if result.rejected:
+                    logger.info("Notion page creation rejected by user")
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. Do not retry or suggest alternatives.",
+                    }
+
+                final_title = result.params.get("title", title)
+                final_content = result.params.get("content", content)
+                final_parent_page_id = result.params.get("parent_page_id")
+                final_connector_id = result.params.get("connector_id", connector_id)
+
+                if not final_title or not final_title.strip():
+                    logger.error("Title is empty or contains only whitespace")
+                    return {
+                        "status": "error",
+                        "message": "Page title cannot be empty. Please provide a valid title.",
+                    }
+
+                logger.info(
+                    f"Creating Notion page with final params: title='{final_title}'"
+                )
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                actual_connector_id = final_connector_id
+                if actual_connector_id is None:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type
+                            == SearchSourceConnectorType.NOTION_CONNECTOR,
+                        )
+                    )
+                    connector = result.scalars().first()
+
+                    if not connector:
+                        logger.warning(
+                            f"No Notion connector found for search_space_id={search_space_id}"
+                        )
+                        return {
+                            "status": "error",
+                            "message": "No Notion connector found. Please connect Notion in your workspace settings.",
+                        }
+
+                    actual_connector_id = connector.id
+                    logger.info(f"Found Notion connector: id={actual_connector_id}")
+                else:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.id == actual_connector_id,
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type
+                            == SearchSourceConnectorType.NOTION_CONNECTOR,
+                        )
+                    )
+                    connector = result.scalars().first()
+
+                    if not connector:
+                        logger.error(
+                            f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
+                        )
+                        return {
+                            "status": "error",
+                            "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+                        }
+                    logger.info(f"Validated Notion connector: id={actual_connector_id}")
+
+                notion_connector = NotionHistoryConnector(
+                    session=db_session,
+                    connector_id=actual_connector_id,
+                )
+
+                result = await notion_connector.create_page(
+                    title=final_title,
+                    content=final_content,
+                    parent_page_id=final_parent_page_id,
+                )
+                logger.info(
+                    f"create_page result: {result.get('status')} - {result.get('message', '')}"
+                )
+
+                if result.get("status") == "success":
+                    kb_message_suffix = ""
+                    try:
+                        from app.services.notion import NotionKBSyncService
+
+                        kb_service = NotionKBSyncService(db_session)
+                        kb_result = await kb_service.sync_after_create(
+                            page_id=result.get("page_id"),
+                            page_title=result.get("title", final_title),
+                            page_url=result.get("url"),
+                            content=final_content,
+                            connector_id=actual_connector_id,
+                            search_space_id=search_space_id,
+                            user_id=user_id,
+                        )
+                        if kb_result["status"] == "success":
+                            kb_message_suffix = (
+                                " Your knowledge base has also been updated."
+                            )
+                        else:
+                            kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+                    except Exception as kb_err:
+                        logger.warning(f"KB sync after create failed: {kb_err}")
+                        kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+
+                    result["message"] = result.get("message", "") + kb_message_suffix
+
+                return result

         except Exception as e:
             from langgraph.errors import GraphInterrupt
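
The create path treats knowledge-base sync as best effort: once the page exists in Notion, a sync failure is downgraded to a trailing message instead of failing the tool. A stripped-down sketch of that shape, with illustrative names that are not part of the diff:

import logging
from collections.abc import Awaitable
from typing import Any

logger = logging.getLogger(__name__)


async def append_kb_suffix(result: dict[str, Any], sync_coro: Awaitable[dict[str, Any]]) -> dict[str, Any]:
    """Run a KB sync coroutine after a successful external write.

    The page already exists remotely, so sync errors must not turn a
    user-visible success into a failure; they only change the message suffix.
    """
    suffix = " This page will be added to your knowledge base in the next scheduled sync."
    try:
        kb_result = await sync_coro
        if kb_result.get("status") == "success":
            suffix = " Your knowledge base has also been updated."
    except Exception as err:  # deliberate best-effort boundary
        logger.warning(f"KB sync failed: {err}")
    result["message"] = result.get("message", "") + suffix
    return result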
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from app.agents.new_chat.tools.hitl import request_approval
 from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
+from app.db import async_session_maker
 from app.services.notion.tool_metadata_service import NotionToolMetadataService

 logger = logging.getLogger(__name__)

@@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
     """
     Factory function to create the delete_notion_page tool.

+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
     Args:
-        db_session: Database session for accessing Notion connector
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
         search_space_id: Search space ID to find the Notion connector
         user_id: User ID for finding the correct Notion connector
         connector_id: Optional specific connector ID (if known)

@@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
     Returns:
         Configured delete_notion_page tool
     """
+    del db_session  # per-call session — see docstring

     @tool
     async def delete_notion_page(

@@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
             f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
         )

-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             logger.error(
                 "Notion tool not properly configured - missing required parameters"
             )

@@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
             }

         try:
-            # Get page context (page_id, account, title) from indexed data
-            metadata_service = NotionToolMetadataService(db_session)
-            context = await metadata_service.get_delete_context(
-                search_space_id, user_id, page_title
-            )
+            async with async_session_maker() as db_session:
+                # Get page context (page_id, account, title) from indexed data
+                metadata_service = NotionToolMetadataService(db_session)
+                context = await metadata_service.get_delete_context(
+                    search_space_id, user_id, page_title
+                )
+
+                if "error" in context:
+                    error_msg = context["error"]
+                    # Check if it's a "not found" error (softer handling for LLM)
+                    if "not found" in error_msg.lower():
+                        logger.warning(f"Page not found: {error_msg}")
+                        return {
+                            "status": "not_found",
+                            "message": error_msg,
+                        }
+                    else:
+                        logger.error(f"Failed to fetch delete context: {error_msg}")
+                        return {
+                            "status": "error",
+                            "message": error_msg,
+                        }
+
+                account = context.get("account", {})
+                if account.get("auth_expired"):
+                    logger.warning(
+                        "Notion account %s has expired authentication",
+                        account.get("id"),
+                    )
+                    return {
+                        "status": "auth_error",
+                        "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
+                    }
+
+                page_id = context.get("page_id")
+                connector_id_from_context = account.get("id")
+                document_id = context.get("document_id")
+
+                logger.info(
+                    f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
+                )
+
+                result = request_approval(
+                    action_type="notion_page_deletion",
+                    tool_name="delete_notion_page",
+                    params={
+                        "page_id": page_id,
+                        "connector_id": connector_id_from_context,
+                        "delete_from_kb": delete_from_kb,
+                    },
+                    context=context,
+                )
+
+                if result.rejected:
+                    logger.info("Notion page deletion rejected by user")
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. Do not retry or suggest alternatives.",
+                    }
+
+                final_page_id = result.params.get("page_id", page_id)
+                final_connector_id = result.params.get(
+                    "connector_id", connector_id_from_context
+                )
+                final_delete_from_kb = result.params.get(
+                    "delete_from_kb", delete_from_kb
+                )
+
+                logger.info(
+                    f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
+                )
+
+                from sqlalchemy.future import select
+
+                from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+                # Validate the connector
+                if final_connector_id:
+                    result = await db_session.execute(
+                        select(SearchSourceConnector).filter(
+                            SearchSourceConnector.id == final_connector_id,
+                            SearchSourceConnector.search_space_id == search_space_id,
+                            SearchSourceConnector.user_id == user_id,
+                            SearchSourceConnector.connector_type
+                            == SearchSourceConnectorType.NOTION_CONNECTOR,
+                        )
+                    )
+                    connector = result.scalars().first()
+
+                    if not connector:
+                        logger.error(
+                            f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
+                        )
+                        return {
+                            "status": "error",
+                            "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
+                        }
+                    actual_connector_id = connector.id
+                    logger.info(f"Validated Notion connector: id={actual_connector_id}")
+                else:
+                    logger.error("No connector found for this page")
+                    return {
+                        "status": "error",
+                        "message": "No connector found for this page.",
+                    }
+
+                # Create connector instance
+                notion_connector = NotionHistoryConnector(
+                    session=db_session,
+                    connector_id=actual_connector_id,
+                )
+
+                # Delete the page from Notion
+                result = await notion_connector.delete_page(page_id=final_page_id)
+                logger.info(
+                    f"delete_page result: {result.get('status')} - {result.get('message', '')}"
+                )
+
+                # If deletion was successful and user wants to delete from KB
+                deleted_from_kb = False
+                if (
+                    result.get("status") == "success"
+                    and final_delete_from_kb
+                    and document_id
+                ):
+                    try:
+                        from sqlalchemy.future import select
+
+                        from app.db import Document
+
+                        # Get the document
+                        doc_result = await db_session.execute(
+                            select(Document).filter(Document.id == document_id)
+                        )
+                        document = doc_result.scalars().first()
+
+                        if document:
+                            await db_session.delete(document)
+                            await db_session.commit()
+                            deleted_from_kb = True
+                            logger.info(
+                                f"Deleted document {document_id} from knowledge base"
+                            )
+                        else:
+                            logger.warning(f"Document {document_id} not found in KB")
+                    except Exception as e:
+                        logger.error(f"Failed to delete document from KB: {e}")
+                        await db_session.rollback()
+                        result["warning"] = (
+                            f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
+                        )
+
+                # Update result with KB deletion status
+                if result.get("status") == "success":
+                    result["deleted_from_kb"] = deleted_from_kb
+                    if deleted_from_kb:
+                        result["message"] = (
+                            f"{result.get('message', '')} (also removed from knowledge base)"
+                        )
+
+                return result
|
||||||
|
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If deletion was successful and user wants to delete from KB
|
||||||
|
deleted_from_kb = False
|
||||||
|
if (
|
||||||
|
result.get("status") == "success"
|
||||||
|
and final_delete_from_kb
|
||||||
|
and document_id
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import Document
|
||||||
|
|
||||||
|
# Get the document
|
||||||
|
doc_result = await db_session.execute(
|
||||||
|
select(Document).filter(Document.id == document_id)
|
||||||
|
)
|
||||||
|
document = doc_result.scalars().first()
|
||||||
|
|
||||||
|
if document:
|
||||||
|
await db_session.delete(document)
|
||||||
|
await db_session.commit()
|
||||||
|
deleted_from_kb = True
|
||||||
|
logger.info(
|
||||||
|
f"Deleted document {document_id} from knowledge base"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Document {document_id} not found in KB")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete document from KB: {e}")
|
||||||
|
await db_session.rollback()
|
||||||
|
result["warning"] = (
|
||||||
|
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update result with KB deletion status
|
||||||
|
if result.get("status") == "success":
|
||||||
|
result["deleted_from_kb"] = deleted_from_kb
|
||||||
|
if deleted_from_kb:
|
||||||
|
result["message"] = (
|
||||||
|
f"{result.get('message', '')} (also removed from knowledge base)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from langgraph.errors import GraphInterrupt
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
|
||||||
|
|
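The delete and update tools above share one human-in-the-loop shape: propose params, pause via request_approval, then re-read the (possibly user-edited) params before acting. A minimal sketch of that consumption pattern, assuming an approval result object with rejected and params attributes; the names and helper below are illustrative, not SurfSense's actual hitl implementation:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class ApprovalResult:
    """Assumed shape of a human-in-the-loop approval outcome."""

    rejected: bool
    params: dict[str, Any] = field(default_factory=dict)


def resolve_final_params(
    result: ApprovalResult, proposed: dict[str, Any]
) -> dict[str, Any] | None:
    """Merge reviewer edits over the tool's proposed params.

    Returns None when the reviewer rejected the action, mirroring the
    early-return branches in the tools above.
    """
    if result.rejected:
        return None
    # Keys the reviewer edited win; untouched keys keep the proposal.
    return {key: result.params.get(key, value) for key, value in proposed.items()}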
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService

logger = logging.getLogger(__name__)

@@ -20,8 +21,14 @@ def create_update_notion_page_tool(
    """
    Factory function to create the update_notion_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache (see
    ``create_create_notion_page_tool`` for the full rationale).

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Notion connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

@@ -29,6 +36,7 @@ def create_update_notion_page_tool(
    Returns:
        Configured update_notion_page tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_notion_page(

@@ -71,7 +79,7 @@ def create_update_notion_page_tool(
            f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
        )

        if search_space_id is None or user_id is None:
            logger.error(
                "Notion tool not properly configured - missing required parameters"
            )

@@ -88,152 +96,155 @@ def create_update_notion_page_tool(
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = NotionToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, page_title
                )

                if "error" in context:
                    error_msg = context["error"]
                    # Check if it's a "not found" error (softer handling for LLM)
                    if "not found" in error_msg.lower():
                        logger.warning(f"Page not found: {error_msg}")
                        return {
                            "status": "not_found",
                            "message": error_msg,
                        }
                    else:
                        logger.error(f"Failed to fetch update context: {error_msg}")
                        return {
                            "status": "error",
                            "message": error_msg,
                        }

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Notion account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
                    }

                page_id = context.get("page_id")
                document_id = context.get("document_id")
                connector_id_from_context = context.get("account", {}).get("id")

                logger.info(
                    f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
                )
                result = request_approval(
                    action_type="notion_page_update",
                    tool_name="update_notion_page",
                    params={
                        "page_id": page_id,
                        "content": content,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if result.rejected:
                    logger.info("Notion page update rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_page_id = result.params.get("page_id", page_id)
                final_content = result.params.get("content", content)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )

                logger.info(
                    f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if final_connector_id:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.NOTION_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()

                    if not connector:
                        logger.error(
                            f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                        )
                        return {
                            "status": "error",
                            "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
                        }
                    actual_connector_id = connector.id
                    logger.info(f"Validated Notion connector: id={actual_connector_id}")
                else:
                    logger.error("No connector found for this page")
                    return {
                        "status": "error",
                        "message": "No connector found for this page.",
                    }

                notion_connector = NotionHistoryConnector(
                    session=db_session,
                    connector_id=actual_connector_id,
                )

                result = await notion_connector.update_page(
                    page_id=final_page_id,
                    content=final_content,
                )
                logger.info(
                    f"update_page result: {result.get('status')} - {result.get('message', '')}"
                )

                if result.get("status") == "success" and document_id is not None:
                    from app.services.notion import NotionKBSyncService

                    logger.info(
                        f"Updating knowledge base for document {document_id}..."
                    )
                    kb_service = NotionKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=document_id,
                        appended_content=final_content,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        appended_block_ids=result.get("appended_block_ids"),
                    )

                    if kb_result["status"] == "success":
                        result["message"] = (
                            f"{result['message']}. Your knowledge base has also been updated."
                        )
                        logger.info(
                            f"Knowledge base successfully updated for page {final_page_id}"
                        )
                    elif kb_result["status"] == "not_indexed":
                        result["message"] = (
                            f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
                        )
                    else:
                        result["message"] = (
                            f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
                        )
                        logger.warning(
                            f"KB update failed for page {final_page_id}: {kb_result['message']}"
                        )

                return result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
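The docstring above describes the per-call session pattern this merge applies across the tool factories. A rough, self-contained sketch of the same idea, assuming a SQLAlchemy async engine; the connection string and the tool itself are placeholders, not SurfSense's actual configuration:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

# Assumption: the real app exposes this as app.db.async_session_maker.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/appdb")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


def create_example_tool(db_session: AsyncSession | None = None):
    """Factory whose closure stays valid when cached across HTTP requests.

    The injected session is discarded; every call opens and closes its own
    short-lived session, so a cache hit never reuses a stale or closed one.
    """
    del db_session  # kept in the signature only for registry compatibility

    async def example_tool() -> dict:
        async with async_session_maker() as session:
            value = (await session.execute(text("SELECT 1"))).scalar()
            return {"ok": value == 1}

    return example_tool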
@@ -10,7 +10,7 @@ from sqlalchemy.future import select

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker

logger = logging.getLogger(__name__)

@@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_onedrive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured create_onedrive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_onedrive_file(
        name: str,

@@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
        """
        logger.info(f"create_onedrive_file called: name='{name}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "OneDrive tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                    )
                )
                connectors = result.scalars().all()

                if not connectors:
                    return {
                        "status": "error",
                        "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
                    }

                accounts = []
                for c in connectors:
                    cfg = c.config or {}
                    accounts.append(
                        {
                            "id": c.id,
                            "name": c.name,
                            "user_email": cfg.get("user_email"),
                            "auth_expired": cfg.get("auth_expired", False),
                        }
                    )

                if all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected OneDrive accounts need re-authentication.",
                        "connector_type": "onedrive",
                    }

                parent_folders: dict[int, list[dict[str, str]]] = {}
                for acc in accounts:
                    cid = acc["id"]
                    if acc.get("auth_expired"):
                        parent_folders[cid] = []
                        continue
                    try:
                        client = OneDriveClient(session=db_session, connector_id=cid)
                        items, err = await client.list_children("root")
                        if err:
                            logger.warning(
                                "Failed to list folders for connector %s: %s", cid, err
                            )
                            parent_folders[cid] = []
                        else:
                            parent_folders[cid] = [
                                {"folder_id": item["id"], "name": item["name"]}
                                for item in items
                                if item.get("folder") is not None
                                and item.get("id")
                                and item.get("name")
                            ]
                    except Exception:
                        logger.warning(
                            "Error fetching folders for connector %s",
                            cid,
                            exc_info=True,
                        )
                        parent_folders[cid] = []

                context: dict[str, Any] = {
                    "accounts": accounts,
                    "parent_folders": parent_folders,
                }

                result = request_approval(
                    action_type="onedrive_file_creation",
                    tool_name="create_onedrive_file",
                    params={
                        "name": name,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_id": None,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_name = result.params.get("name", name)
                final_content = result.params.get("content", content)
                final_connector_id = result.params.get("connector_id")
                final_parent_folder_id = result.params.get("parent_folder_id")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                final_name = _ensure_docx_extension(final_name)

                if final_connector_id is not None:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                        )
                    )
                    connector = result.scalars().first()
                else:
                    connector = connectors[0]

                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected OneDrive connector is invalid.",
                    }

                docx_bytes = _markdown_to_docx(final_content or "")

                client = OneDriveClient(session=db_session, connector_id=connector.id)
                created = await client.create_file(
                    name=final_name,
                    parent_id=final_parent_folder_id,
                    content=docx_bytes,
                    mime_type=DOCX_MIME,
                )

                logger.info(
                    f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
                )

                kb_message_suffix = ""
                try:
                    from app.services.onedrive import OneDriveKBSyncService

                    kb_service = OneDriveKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=created.get("id"),
                        file_name=created.get("name", final_name),
                        mime_type=DOCX_MIME,
                        web_url=created.get("webUrl"),
                        content=final_content,
                        connector_id=connector.id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": created.get("id"),
                    "name": created.get("name"),
                    "web_url": created.get("webUrl"),
                    "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt
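The tool above assumes a _markdown_to_docx helper and a DOCX_MIME constant defined elsewhere in the module. As a rough illustration only (the real helper presumably renders Markdown properly), a naive version could look like this, using python-docx and treating each line as a plain paragraph:

import io

from docx import Document  # python-docx

DOCX_MIME = (
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)


def markdown_to_docx_bytes(markdown_text: str) -> bytes:
    """Very rough Markdown-to-.docx conversion: one paragraph per line."""
    doc = Document()
    for line in (markdown_text or "").splitlines() or [""]:
        doc.add_paragraph(line)
    buffer = io.BytesIO()
    doc.save(buffer)
    return buffer.getvalue()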
@@ -13,6 +13,7 @@ from app.db import (
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
    async_session_maker,
)

logger = logging.getLogger(__name__)

@@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_onedrive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.

    Returns:
        Configured delete_onedrive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_onedrive_file(
        file_name: str,

@@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
            f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "OneDrive tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                doc_result = await db_session.execute(
                    select(Document)
                    .join(

@@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
                        and_(
                            Document.search_space_id == search_space_id,
                            Document.document_type == DocumentType.ONEDRIVE_FILE,
                            func.lower(Document.title) == func.lower(file_name),
                            SearchSourceConnector.user_id == user_id,
                        )
                    )

@@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
                )
                document = doc_result.scalars().first()

                if not document:
                    doc_result = await db_session.execute(
                        select(Document)
                        .join(
                            SearchSourceConnector,
                            Document.connector_id == SearchSourceConnector.id,
                        )
                        .filter(
                            and_(
                                Document.search_space_id == search_space_id,
                                Document.document_type == DocumentType.ONEDRIVE_FILE,
                                func.lower(
                                    cast(
                                        Document.document_metadata[
                                            "onedrive_file_name"
                                        ],
                                        String,
                                    )
                                )
                                == func.lower(file_name),
                                SearchSourceConnector.user_id == user_id,
                            )
                        )
                        .order_by(Document.updated_at.desc().nullslast())
                        .limit(1)
                    )
                    document = doc_result.scalars().first()

                if not document:
                    return {
                        "status": "not_found",
                        "message": (
                            f"File '{file_name}' not found in your indexed OneDrive files. "
                            "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
                            "or (3) the file name is different."
                        ),
                    }

                if not document.connector_id:
                    return {
                        "status": "error",
                        "message": "Document has no associated connector.",
                    }

                meta = document.document_metadata or {}
                file_id = meta.get("onedrive_file_id")
                document_id = document.id

                if not file_id:
                    return {
                        "status": "error",
                        "message": "File ID is missing. Please re-index the file.",
                    }

                conn_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        and_(
                            SearchSourceConnector.id == document.connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type

@@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
                            == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                        )
                    )
                )
                connector = conn_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "OneDrive connector not found or access denied.",
                    }

                cfg = connector.config or {}
                if cfg.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "onedrive",
                    }

                context = {
                    "file": {
                        "file_id": file_id,
                        "name": file_name,
                        "document_id": document_id,
                        "web_url": meta.get("web_url"),
                    },
                    "account": {
                        "id": connector.id,
                        "name": connector.name,
                        "user_email": cfg.get("user_email"),
                    },
                }

                result = request_approval(
                    action_type="onedrive_file_trash",
                    tool_name="delete_onedrive_file",
                    params={
                        "file_id": file_id,
                        "connector_id": connector.id,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_file_id = result.params.get("file_id", file_id)
                final_connector_id = result.params.get("connector_id", connector.id)
                final_delete_from_kb = result.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if final_connector_id != connector.id:
                    result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            and_(
                                SearchSourceConnector.id == final_connector_id,
                                SearchSourceConnector.search_space_id
                                == search_space_id,
                                SearchSourceConnector.user_id == user_id,
                                SearchSourceConnector.connector_type
                                == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                            )
                        )
                    )
                    validated_connector = result.scalars().first()
                    if not validated_connector:
                        return {
                            "status": "error",
                            "message": "Selected OneDrive connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = validated_connector.id
                else:
                    actual_connector_id = connector.id

                logger.info(
                    f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
                )

                client = OneDriveClient(
                    session=db_session, connector_id=actual_connector_id
                )
                await client.trash_file(final_file_id)

                logger.info(
                    f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
                )

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": final_file_id,
                    "message": f"Successfully moved '{file_name}' to the recycle bin.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        doc = doc_result.scalars().first()
                        if doc:
                            await db_session.delete(doc)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt
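The fallback lookup above matches on a key inside the JSON document_metadata column, cast to text so it can be lowercased. A small self-contained sketch of that query pattern against a toy model; table, column, and key names here are illustrative, not SurfSense's schema:

from sqlalchemy import JSON, Integer, String, Text, cast, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    title: Mapped[str] = mapped_column(Text)
    document_metadata: Mapped[dict] = mapped_column(JSON, default=dict)


def build_lookup_queries(file_name: str):
    """Title match first, JSON-key fallback second, both case-insensitive."""
    by_title = select(Doc).filter(func.lower(Doc.title) == func.lower(file_name))
    # Indexing a JSON column yields a JSON element; cast it to text before lower().
    by_metadata = select(Doc).filter(
        func.lower(cast(Doc.document_metadata["onedrive_file_name"], String))
        == func.lower(file_name)
    )
    return by_title, by_metadata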
@@ -824,13 +824,22 @@ async def build_tools_async(
    """Async version of build_tools that also loads MCP tools from database.

    Design Note:
        This function exists because MCP tools require database queries to load
        user configs, while built-in tools are created synchronously from static
        code.

        Alternative: We could make build_tools() itself async and always query
        the database, but that would force async everywhere even when only using
        built-in tools. The current design keeps the simple case (static tools
        only) synchronous while supporting dynamic database-loaded tools through
        this async wrapper.

        Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
        avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
        the event loop) are kicked off concurrently. Cold-path savings are
        bounded by the slower of the two — typically MCP at ~200ms-1.7s —
        so the parallelization recovers the ~50-200ms previously spent
        serially on built-in construction.

    Args:
        dependencies: Dict containing all possible dependencies

@@ -843,33 +852,70 @@ async def build_tools_async(
        List of configured tool instances ready for the agent, including MCP tools.

    """
    import asyncio
    import time

    _perf_log = logging.getLogger("surfsense.perf")
    _perf_log.setLevel(logging.DEBUG)

    can_load_mcp = (
        include_mcp_tools
        and "db_session" in dependencies
        and "search_space_id" in dependencies
    )

    # Built-in tool construction is synchronous + CPU-only. Off-loop it so
    # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
    # function over its inputs — safe to thread-shift.
    _t0 = time.perf_counter()
    builtin_task = asyncio.create_task(
        asyncio.to_thread(
            build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
        )
    )

    mcp_task: asyncio.Task | None = None
    if can_load_mcp:
        mcp_task = asyncio.create_task(
            load_mcp_tools(
                dependencies["db_session"],
                dependencies["search_space_id"],
            )
        )

    # Surface failures from each task independently so a flaky MCP
    # endpoint never poisons built-in tool registration. ``return_exceptions``
    # gives us per-task exceptions instead of dropping the second result
    # when the first raises.
    if mcp_task is not None:
        builtin_result, mcp_result = await asyncio.gather(
            builtin_task, mcp_task, return_exceptions=True
        )
    else:
        builtin_result = await builtin_task
        mcp_result = None

    if isinstance(builtin_result, BaseException):
        raise builtin_result  # built-in registration failure is non-recoverable
    tools: list[BaseTool] = builtin_result
    _perf_log.info(
        "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
        time.perf_counter() - _t0,
        len(tools),
    )

    if mcp_task is not None:
        if isinstance(mcp_result, BaseException):
            # ``return_exceptions=True`` captures the exception out-of-band,
            # so ``sys.exc_info()`` is empty here. Pass the captured
            # exception via ``exc_info=`` to get a real traceback.
            logging.error(
                "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
            )
        else:
            mcp_tools = mcp_result or []
            _perf_log.info(
                "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
                time.perf_counter() - _t0,
                len(mcp_tools),
            )

@@ -879,8 +925,6 @@ async def build_tools_async(
                len(mcp_tools),
                [t.name for t in mcp_tools],
            )

    logging.info(
        "Total tools for agent: %d — %s",
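The hunk above interleaves two independent start-ups: the synchronous builder is pushed to a worker thread while the MCP loader awaits on the event loop, and gather(return_exceptions=True) keeps one failure from hiding the other result. A condensed, runnable sketch of that pattern with placeholder builders rather than the project's real functions:

import asyncio
import logging


def build_builtin_tools() -> list[str]:
    # Stand-in for the synchronous, CPU-only construction step.
    return ["search", "create_page"]


async def load_remote_tools() -> list[str]:
    # Stand-in for the async HTTP/DB loading step.
    await asyncio.sleep(0.1)
    return ["mcp_tool"]


async def build_all_tools() -> list[str]:
    builtin_task = asyncio.create_task(asyncio.to_thread(build_builtin_tools))
    remote_task = asyncio.create_task(load_remote_tools())

    builtin_result, remote_result = await asyncio.gather(
        builtin_task, remote_task, return_exceptions=True
    )

    if isinstance(builtin_result, BaseException):
        raise builtin_result  # a failed built-in build is fatal
    tools = list(builtin_result)

    if isinstance(remote_result, BaseException):
        # gather() captured this out-of-band; pass it via exc_info for a traceback.
        logging.error("Remote tools failed: %s", remote_result, exc_info=remote_result)
    else:
        tools.extend(remote_result)
    return tools


if __name__ == "__main__":
    print(asyncio.run(build_all_tools()))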
@@ -15,7 +15,7 @@ from langchain_core.tools import tool
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
+from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
 from app.utils.document_converters import embed_text
 
 
@@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
     """
     Factory function to create the search_surfsense_docs tool.
 
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
     Args:
-        db_session: Database session for executing queries
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
 
     Returns:
         A configured tool function for searching Surfsense documentation
     """
+    del db_session  # per-call session — see docstring
 
     @tool
     async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
@@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
         Returns:
             Relevant documentation content formatted with chunk IDs for citations
         """
-        return await search_surfsense_docs_async(
-            query=query,
-            db_session=db_session,
-            top_k=top_k,
-        )
+        async with async_session_maker() as db_session:
+            return await search_surfsense_docs_async(
+                query=query,
+                db_session=db_session,
+                top_k=top_k,
+            )
 
     return search_surfsense_docs
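The same factory change recurs in the Teams and memory tool hunks below: the closure stops capturing the request-scoped session and opens a short-lived one per call instead. A small self-contained sketch of that shape, using a placeholder in-memory SQLite engine rather than the project's real ``app.db`` wiring:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# Placeholder engine/URL; the real project wires this up in app.db.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


def create_search_tool(db_session=None):
    # The parameter is kept only so existing registry call sites still work.
    del db_session

    async def search(query: str) -> list[str]:
        # A fresh session per invocation: safe to call long after the
        # factory's original request (and its session) are gone.
        async with async_session_maker() as session:
            result = await session.execute(text("SELECT :q"), {"q": query})
            return [row[0] for row in result]

    return search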
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.db import async_session_maker
+
 from ._auth import GRAPH_API, get_access_token, get_teams_connector
 
 logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the list_teams_channels tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured list_teams_channels tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def list_teams_channels() -> dict[str, Any]:
         """List all Microsoft Teams and their channels the user has access to.
@@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
         Dictionary with status and a list of teams, each containing
         team_id, team_name, and a list of channels (id, name).
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Teams tool not properly configured."}
 
         try:
-            connector = await get_teams_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Teams connector found."}
-
-            token = await get_access_token(db_session, connector)
-            headers = {"Authorization": f"Bearer {token}"}
-
-            async with httpx.AsyncClient(timeout=20.0) as client:
-                teams_resp = await client.get(
-                    f"{GRAPH_API}/me/joinedTeams", headers=headers
-                )
-
-            if teams_resp.status_code == 401:
-                return {
-                    "status": "auth_error",
-                    "message": "Teams token expired. Please re-authenticate.",
-                    "connector_type": "teams",
-                }
-            if teams_resp.status_code != 200:
-                return {
-                    "status": "error",
-                    "message": f"Graph API error: {teams_resp.status_code}",
-                }
-
-            teams_data = teams_resp.json().get("value", [])
-            result_teams = []
-
-            async with httpx.AsyncClient(timeout=20.0) as client:
-                for team in teams_data:
-                    team_id = team["id"]
-                    ch_resp = await client.get(
-                        f"{GRAPH_API}/teams/{team_id}/channels",
-                        headers=headers,
-                    )
-                    channels = []
-                    if ch_resp.status_code == 200:
-                        channels = [
-                            {"id": ch["id"], "name": ch.get("displayName", "")}
-                            for ch in ch_resp.json().get("value", [])
-                        ]
-                    result_teams.append(
-                        {
-                            "team_id": team_id,
-                            "team_name": team.get("displayName", ""),
-                            "channels": channels,
-                        }
-                    )
-
-            return {
-                "status": "success",
-                "teams": result_teams,
-                "total_teams": len(result_teams),
-            }
+            async with async_session_maker() as db_session:
+                connector = await get_teams_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Teams connector found."}
+
+                token = await get_access_token(db_session, connector)
+                headers = {"Authorization": f"Bearer {token}"}
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    teams_resp = await client.get(
+                        f"{GRAPH_API}/me/joinedTeams", headers=headers
+                    )
+
+                if teams_resp.status_code == 401:
+                    return {
+                        "status": "auth_error",
+                        "message": "Teams token expired. Please re-authenticate.",
+                        "connector_type": "teams",
+                    }
+                if teams_resp.status_code != 200:
+                    return {
+                        "status": "error",
+                        "message": f"Graph API error: {teams_resp.status_code}",
+                    }
+
+                teams_data = teams_resp.json().get("value", [])
+                result_teams = []
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    for team in teams_data:
+                        team_id = team["id"]
+                        ch_resp = await client.get(
+                            f"{GRAPH_API}/teams/{team_id}/channels",
+                            headers=headers,
+                        )
+                        channels = []
+                        if ch_resp.status_code == 200:
+                            channels = [
+                                {"id": ch["id"], "name": ch.get("displayName", "")}
+                                for ch in ch_resp.json().get("value", [])
+                            ]
+                        result_teams.append(
+                            {
+                                "team_id": team_id,
+                                "team_name": team.get("displayName", ""),
+                                "channels": channels,
+                            }
+                        )
+
+                return {
+                    "status": "success",
+                    "teams": result_teams,
+                    "total_teams": len(result_teams),
+                }
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
@@ -5,6 +5,8 @@ import httpx
 from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.db import async_session_maker
+
 from ._auth import GRAPH_API, get_access_token, get_teams_connector
 
 logger = logging.getLogger(__name__)
@@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the read_teams_messages tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured read_teams_messages tool
+    """
+    del db_session  # per-call session — see docstring
+
    @tool
    async def read_teams_messages(
        team_id: str,
@@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
         Dictionary with status and a list of messages including
         id, sender, content, timestamp.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Teams tool not properly configured."}
 
         limit = min(limit, 50)
 
         try:
-            connector = await get_teams_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Teams connector found."}
-
-            token = await get_access_token(db_session, connector)
-
-            async with httpx.AsyncClient(timeout=20.0) as client:
-                resp = await client.get(
-                    f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
-                    headers={"Authorization": f"Bearer {token}"},
-                    params={"$top": limit},
-                )
-
-            if resp.status_code == 401:
-                return {
-                    "status": "auth_error",
-                    "message": "Teams token expired. Please re-authenticate.",
-                    "connector_type": "teams",
-                }
-            if resp.status_code == 403:
-                return {
-                    "status": "error",
-                    "message": "Insufficient permissions to read this channel.",
-                }
-            if resp.status_code != 200:
-                return {
-                    "status": "error",
-                    "message": f"Graph API error: {resp.status_code}",
-                }
-
-            raw_msgs = resp.json().get("value", [])
-            messages = []
-            for m in raw_msgs:
-                sender = m.get("from", {})
-                user_info = sender.get("user", {}) if sender else {}
-                body = m.get("body", {})
-                messages.append(
-                    {
-                        "id": m.get("id"),
-                        "sender": user_info.get("displayName", "Unknown"),
-                        "content": body.get("content", ""),
-                        "content_type": body.get("contentType", "text"),
-                        "timestamp": m.get("createdDateTime", ""),
-                    }
-                )
-
-            return {
-                "status": "success",
-                "team_id": team_id,
-                "channel_id": channel_id,
-                "messages": messages,
-                "total": len(messages),
-            }
+            async with async_session_maker() as db_session:
+                connector = await get_teams_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Teams connector found."}
+
+                token = await get_access_token(db_session, connector)
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    resp = await client.get(
+                        f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
+                        headers={"Authorization": f"Bearer {token}"},
+                        params={"$top": limit},
+                    )
+
+                if resp.status_code == 401:
+                    return {
+                        "status": "auth_error",
+                        "message": "Teams token expired. Please re-authenticate.",
+                        "connector_type": "teams",
+                    }
+                if resp.status_code == 403:
+                    return {
+                        "status": "error",
+                        "message": "Insufficient permissions to read this channel.",
+                    }
+                if resp.status_code != 200:
+                    return {
+                        "status": "error",
+                        "message": f"Graph API error: {resp.status_code}",
+                    }
+
+                raw_msgs = resp.json().get("value", [])
+                messages = []
+                for m in raw_msgs:
+                    sender = m.get("from", {})
+                    user_info = sender.get("user", {}) if sender else {}
+                    body = m.get("body", {})
+                    messages.append(
+                        {
+                            "id": m.get("id"),
+                            "sender": user_info.get("displayName", "Unknown"),
+                            "content": body.get("content", ""),
+                            "content_type": body.get("contentType", "text"),
+                            "timestamp": m.get("createdDateTime", ""),
+                        }
+                    )
+
+                return {
+                    "status": "success",
+                    "team_id": team_id,
+                    "channel_id": channel_id,
+                    "messages": messages,
+                    "total": len(messages),
+                }
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.agents.new_chat.tools.hitl import request_approval
+from app.db import async_session_maker
 
 from ._auth import GRAPH_API, get_access_token, get_teams_connector
 
@@ -17,6 +18,23 @@ def create_send_teams_message_tool(
     search_space_id: int | None = None,
     user_id: str | None = None,
 ):
+    """
+    Factory function to create the send_teams_message tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+
+    Args:
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+
+    Returns:
+        Configured send_teams_message tool
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def send_teams_message(
         team_id: str,
@@ -39,70 +57,73 @@ def create_send_teams_message_tool(
         IMPORTANT:
         - If status is "rejected", the user explicitly declined. Do NOT retry.
         """
-        if db_session is None or search_space_id is None or user_id is None:
+        if search_space_id is None or user_id is None:
             return {"status": "error", "message": "Teams tool not properly configured."}
 
         try:
-            connector = await get_teams_connector(db_session, search_space_id, user_id)
-            if not connector:
-                return {"status": "error", "message": "No Teams connector found."}
-
-            result = request_approval(
-                action_type="teams_send_message",
-                tool_name="send_teams_message",
-                params={
-                    "team_id": team_id,
-                    "channel_id": channel_id,
-                    "content": content,
-                },
-                context={"connector_id": connector.id},
-            )
-
-            if result.rejected:
-                return {
-                    "status": "rejected",
-                    "message": "User declined. Message was not sent.",
-                }
-
-            final_content = result.params.get("content", content)
-            final_team = result.params.get("team_id", team_id)
-            final_channel = result.params.get("channel_id", channel_id)
-
-            token = await get_access_token(db_session, connector)
-
-            async with httpx.AsyncClient(timeout=20.0) as client:
-                resp = await client.post(
-                    f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
-                    headers={
-                        "Authorization": f"Bearer {token}",
-                        "Content-Type": "application/json",
-                    },
-                    json={"body": {"content": final_content}},
-                )
-
-            if resp.status_code == 401:
-                return {
-                    "status": "auth_error",
-                    "message": "Teams token expired. Please re-authenticate.",
-                    "connector_type": "teams",
-                }
-            if resp.status_code == 403:
-                return {
-                    "status": "insufficient_permissions",
-                    "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
-                }
-            if resp.status_code not in (200, 201):
-                return {
-                    "status": "error",
-                    "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
-                }
-
-            msg_data = resp.json()
-            return {
-                "status": "success",
-                "message_id": msg_data.get("id"),
-                "message": "Message sent to Teams channel.",
-            }
+            async with async_session_maker() as db_session:
+                connector = await get_teams_connector(
+                    db_session, search_space_id, user_id
+                )
+                if not connector:
+                    return {"status": "error", "message": "No Teams connector found."}
+
+                result = request_approval(
+                    action_type="teams_send_message",
+                    tool_name="send_teams_message",
+                    params={
+                        "team_id": team_id,
+                        "channel_id": channel_id,
+                        "content": content,
+                    },
+                    context={"connector_id": connector.id},
+                )
+
+                if result.rejected:
+                    return {
+                        "status": "rejected",
+                        "message": "User declined. Message was not sent.",
+                    }
+
+                final_content = result.params.get("content", content)
+                final_team = result.params.get("team_id", team_id)
+                final_channel = result.params.get("channel_id", channel_id)
+
+                token = await get_access_token(db_session, connector)
+
+                async with httpx.AsyncClient(timeout=20.0) as client:
+                    resp = await client.post(
+                        f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
+                        headers={
+                            "Authorization": f"Bearer {token}",
+                            "Content-Type": "application/json",
+                        },
+                        json={"body": {"content": final_content}},
+                    )
+
+                if resp.status_code == 401:
+                    return {
+                        "status": "auth_error",
+                        "message": "Teams token expired. Please re-authenticate.",
+                        "connector_type": "teams",
+                    }
+                if resp.status_code == 403:
+                    return {
+                        "status": "insufficient_permissions",
+                        "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
+                    }
+                if resp.status_code not in (200, 201):
+                    return {
+                        "status": "error",
+                        "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
+                    }
+
+                msg_data = resp.json()
+                return {
+                    "status": "success",
+                    "message_id": msg_data.get("id"),
+                    "message": "Message sent to Teams channel.",
+                }
 
         except Exception as e:
             from langgraph.errors import GraphInterrupt
@@ -26,7 +26,7 @@ from langchain_core.tools import tool
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.db import SearchSpace, User
+from app.db import SearchSpace, User, async_session_maker
 
 logger = logging.getLogger(__name__)
 
@@ -295,6 +295,25 @@ def create_update_memory_tool(
     db_session: AsyncSession,
     llm: Any | None = None,
 ):
+    """Factory function to create the user-memory update tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+    The session's bound ``commit``/``rollback`` methods are captured at
+    call time, after ``async with`` has bound ``db_session`` locally.
+
+    Args:
+        user_id: ID of the user whose memory document is being updated.
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+        llm: Optional LLM for the forced-rewrite path.
+
+    Returns:
+        Configured update_memory tool for the user-memory scope.
+    """
+    del db_session  # per-call session — see docstring
     uid = UUID(user_id) if isinstance(user_id, str) else user_id
 
     @tool
@@ -311,26 +330,26 @@ def create_update_memory_tool(
             updated_memory: The FULL updated markdown document (not a diff).
         """
         try:
-            result = await db_session.execute(select(User).where(User.id == uid))
-            user = result.scalars().first()
-            if not user:
-                return {"status": "error", "message": "User not found."}
-
-            old_memory = user.memory_md
-
-            return await _save_memory(
-                updated_memory=updated_memory,
-                old_memory=old_memory,
-                llm=llm,
-                apply_fn=lambda content: setattr(user, "memory_md", content),
-                commit_fn=db_session.commit,
-                rollback_fn=db_session.rollback,
-                label="memory",
-                scope="user",
-            )
+            async with async_session_maker() as db_session:
+                result = await db_session.execute(select(User).where(User.id == uid))
+                user = result.scalars().first()
+                if not user:
+                    return {"status": "error", "message": "User not found."}
+
+                old_memory = user.memory_md
+
+                return await _save_memory(
+                    updated_memory=updated_memory,
+                    old_memory=old_memory,
+                    llm=llm,
+                    apply_fn=lambda content: setattr(user, "memory_md", content),
+                    commit_fn=db_session.commit,
+                    rollback_fn=db_session.rollback,
+                    label="memory",
+                    scope="user",
+                )
         except Exception as e:
             logger.exception("Failed to update user memory: %s", e)
-            await db_session.rollback()
             return {
                 "status": "error",
                 "message": f"Failed to update memory: {e}",
@@ -344,6 +363,27 @@ def create_update_team_memory_tool(
     db_session: AsyncSession,
     llm: Any | None = None,
 ):
+    """Factory function to create the team-memory update tool.
+
+    The tool acquires its own short-lived ``AsyncSession`` per call via
+    :data:`async_session_maker` so the closure is safe to share across
+    HTTP requests by the compiled-agent cache. Capturing a per-request
+    session here would surface stale/closed sessions on cache hits.
+    The session's bound ``commit``/``rollback`` methods are captured at
+    call time, after ``async with`` has bound ``db_session`` locally.
+
+    Args:
+        search_space_id: ID of the search space whose team memory is being
+            updated.
+        db_session: Reserved for registry compatibility. Per-call sessions
+            are opened via :data:`async_session_maker` inside the tool body.
+        llm: Optional LLM for the forced-rewrite path.
+
+    Returns:
+        Configured update_memory tool for the team-memory scope.
+    """
+    del db_session  # per-call session — see docstring
+
     @tool
     async def update_memory(updated_memory: str) -> dict[str, Any]:
         """Update the team's shared memory document for this search space.
@@ -359,28 +399,30 @@ def create_update_team_memory_tool(
             updated_memory: The FULL updated markdown document (not a diff).
         """
         try:
-            result = await db_session.execute(
-                select(SearchSpace).where(SearchSpace.id == search_space_id)
-            )
-            space = result.scalars().first()
-            if not space:
-                return {"status": "error", "message": "Search space not found."}
-
-            old_memory = space.shared_memory_md
-
-            return await _save_memory(
-                updated_memory=updated_memory,
-                old_memory=old_memory,
-                llm=llm,
-                apply_fn=lambda content: setattr(space, "shared_memory_md", content),
-                commit_fn=db_session.commit,
-                rollback_fn=db_session.rollback,
-                label="team memory",
-                scope="team",
-            )
+            async with async_session_maker() as db_session:
+                result = await db_session.execute(
+                    select(SearchSpace).where(SearchSpace.id == search_space_id)
+                )
+                space = result.scalars().first()
+                if not space:
+                    return {"status": "error", "message": "Search space not found."}
+
+                old_memory = space.shared_memory_md
+
+                return await _save_memory(
+                    updated_memory=updated_memory,
+                    old_memory=old_memory,
+                    llm=llm,
+                    apply_fn=lambda content: setattr(
+                        space, "shared_memory_md", content
+                    ),
+                    commit_fn=db_session.commit,
+                    rollback_fn=db_session.rollback,
+                    label="team memory",
+                    scope="team",
+                )
         except Exception as e:
             logger.exception("Failed to update team memory: %s", e)
-            await db_session.rollback()
             return {
                 "status": "error",
                 "message": f"Failed to update team memory: {e}",
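One detail worth calling out from the memory-tool docstrings above: the helper receives ``db_session.commit`` / ``db_session.rollback`` as bound methods captured inside the ``async with`` block, so it can commit the caller's transaction without ever seeing the session object. A tiny illustrative sketch (the ``FakeSession`` and ``save_with`` names are invented for the example):

import asyncio
from collections.abc import Awaitable, Callable


async def save_with(commit_fn: Callable[[], Awaitable[None]]) -> str:
    # The helper never needs the session object itself, only the ability
    # to commit the transaction the caller already opened.
    await commit_fn()
    return "saved"


class FakeSession:
    async def commit(self) -> None:
        print("committed")


async def main() -> None:
    session = FakeSession()
    # Bound method captured while the session is alive; the helper can
    # await it without importing or seeing the session class at all.
    print(await save_with(session.commit))


asyncio.run(main())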
@@ -31,6 +31,7 @@ from app.config import (
     initialize_image_gen_router,
     initialize_llm_router,
     initialize_openrouter_integration,
+    initialize_pricing_registration,
     initialize_vision_llm_router,
 )
 from app.db import User, create_db_and_tables, get_async_session
@@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
     OpenRouterIntegrationService.get_instance().stop_background_refresh()
 
 
+async def _warm_agent_jit_caches() -> None:
+    """Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
+
+    Why
+    ----
+    A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
+    generation chain takes 1.5-2 seconds of pure CPU on first invocation
+    inside any Python process: the graph compiler builds reducers,
+    Pydantic v2 generates and JITs validator schemas, deepagents
+    eagerly compiles its general-purpose subagent, etc. Subsequent
+    compiles in the same process pay only ~50% of that cost (the lazy
+    JIT bits are cached in module-level dicts).
+
+    Doing one throwaway compile during ``lifespan`` startup pre-pays
+    that cost so the *first real request* doesn't. We do NOT prime
+    :mod:`agent_cache` because the cache key requires real
+    ``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the
+    throwaway agent is genuinely thrown away and immediately collected.
+
+    Safety
+    ------
+    * No DB access. We construct a stub LLM (no real keys), pass an
+      empty tools list, and pass ``checkpointer=None`` so we never
+      touch Postgres.
+    * Bounded by ``asyncio.wait_for`` so a hang here can never block
+      worker startup. On any failure, we log + swallow — the worst
+      case is the first real request pays the full cold cost (i.e.
+      pre-warmup behaviour).
+    """
+    import time as _time
+
+    logger = logging.getLogger(__name__)
+    t0 = _time.perf_counter()
+    try:
+        from langchain.agents import create_agent
+        from langchain.agents.middleware import (
+            ModelCallLimitMiddleware,
+            TodoListMiddleware,
+            ToolCallLimitMiddleware,
+        )
+        from langchain_core.language_models.fake_chat_models import (
+            FakeListChatModel,
+        )
+        from langchain_core.tools import tool
+
+        from app.agents.new_chat.context import SurfSenseContextSchema
+
+        # Minimal LLM stub. ``FakeListChatModel`` satisfies
+        # ``BaseChatModel`` without any network or auth — perfect for
+        # exercising the compile path without side effects.
+        stub_llm = FakeListChatModel(responses=["warmup-response"])
+
+        # Two trivial tools with arg + return schemas — exercises the
+        # Pydantic v2 schema JIT path. Without at least one tool the
+        # graph compile skips the tool-loop bytecode generation that
+        # accounts for ~30-50% of cold compile cost.
+        @tool
+        def _warmup_tool_a(query: str, limit: int = 5) -> str:
+            """Warmup tool A — never actually invoked."""
+            return query[:limit]
+
+        @tool
+        def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
+            """Warmup tool B — never actually invoked."""
+            return {"name": name, "value": value}
+
+        # A handful of common middleware so the compile pre-pays the
+        # ``AgentMiddleware`` resolver path. These instances never run
+        # because the throwaway agent is immediately collected.
+        # ``SubAgentMiddleware`` is the single heaviest line in cold
+        # ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
+        # compile its general-purpose subagent's full inner graph),
+        # so we include it here to make sure that compile path is JIT'd.
+        warmup_middleware: list = [
+            TodoListMiddleware(),
+            ModelCallLimitMiddleware(
+                thread_limit=120, run_limit=80, exit_behavior="end"
+            ),
+            ToolCallLimitMiddleware(
+                thread_limit=300, run_limit=80, exit_behavior="continue"
+            ),
+        ]
+        try:
+            from deepagents import SubAgentMiddleware
+            from deepagents.backends import StateBackend
+            from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
+
+            gp_warmup_spec = {  # type: ignore[var-annotated]
+                **GENERAL_PURPOSE_SUBAGENT,
+                "model": stub_llm,
+                "tools": [_warmup_tool_a],
+                "middleware": [TodoListMiddleware()],
+            }
+            warmup_middleware.append(
+                SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
+            )
+        except Exception:
+            # Deepagents missing/incompatible — middleware-only warmup
+            # still produces a useful (smaller) speedup.
+            logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
+
+        compiled = create_agent(
+            stub_llm,
+            tools=[_warmup_tool_a, _warmup_tool_b],
+            system_prompt="You are a warmup stub.",
+            middleware=warmup_middleware,
+            context_schema=SurfSenseContextSchema,
+            checkpointer=None,
+        )
+
+        # Touch the compiled graph's stream_channels / nodes so any
+        # remaining lazy schema work fires now instead of on first
+        # real invocation.
+        _ = list(getattr(compiled, "nodes", {}).keys())
+
+        del compiled
+        logger.info(
+            "[startup] Agent JIT warmup completed in %.3fs",
+            _time.perf_counter() - t0,
+        )
+    except Exception:
+        logger.warning(
+            "[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
+            "real request will pay the full compile cost)",
+            _time.perf_counter() - t0,
+            exc_info=True,
+        )
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Tune GC: lower gen-2 threshold so long-lived garbage is collected
@@ -432,6 +562,7 @@ async def lifespan(app: FastAPI):
     await setup_checkpointer_tables()
     initialize_openrouter_integration()
     _start_openrouter_background_refresh()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()
@@ -443,6 +574,18 @@ async def lifespan(app: FastAPI):
             "Docs will be indexed on the next restart."
         )
 
+    # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
+    # worker readiness. ``shield`` so Uvicorn cancelling startup
+    # doesn't leave half-warmed Pydantic schemas in an inconsistent
+    # state.
+    try:
+        await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
+    except (TimeoutError, Exception):  # pragma: no cover - defensive
+        logging.getLogger(__name__).warning(
+            "[startup] Agent JIT warmup hit timeout/error — skipping; "
+            "first real request will pay the full compile cost."
+        )
+
     log_system_snapshot("startup_complete")
 
     yield
@@ -452,6 +595,23 @@
 
 
 def registration_allowed():
+    """Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
+
+    Despite the name, this dependency does NOT only gate registration. When
+    REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
+    that could mint or refresh a session for an attacker:
+
+    * email/password ``POST /auth/register``
+    * email/password ``POST /auth/jwt/login``
+    * the Google OAuth router (``/auth/google/authorize`` and the shared
+      ``/auth/google/callback`` handles both new signups and login for
+      existing users, so flipping this off locks both)
+    * the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
+
+    Use it as a temporary "freeze all new sessions" lever during incident
+    response. It is not a way to disable signup while keeping login working;
+    for that, override ``UserManager.oauth_callback`` instead.
+    """
     if not config.REGISTRATION_ENABLED:
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
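As a usage sketch of the kill-switch dependency documented above (with a module-level boolean standing in for ``config.REGISTRATION_ENABLED``), any guarded auth route fails fast with 403 while the flag is off:

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.testclient import TestClient

REGISTRATION_ENABLED = False  # stand-in for config.REGISTRATION_ENABLED


def registration_allowed() -> None:
    if not REGISTRATION_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
        )


app = FastAPI()


@app.post("/auth/register", dependencies=[Depends(registration_allowed)])
def register() -> dict[str, str]:
    return {"status": "created"}


# With the flag off, every guarded auth surface fails fast with 403.
print(TestClient(app).post("/auth/register").status_code)  # 403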
@@ -596,32 +756,45 @@ app.add_middleware(
     allow_headers=["*"],  # Allows all headers
 )
 
-app.include_router(
-    fastapi_users.get_auth_router(auth_backend),
-    prefix="/auth/jwt",
-    tags=["auth"],
-    dependencies=[Depends(rate_limit_login)],
-)
-app.include_router(
-    fastapi_users.get_register_router(UserRead, UserCreate),
-    prefix="/auth",
-    tags=["auth"],
-    dependencies=[
-        Depends(rate_limit_register),
-        Depends(registration_allowed),  # blocks registration when disabled
-    ],
-)
-app.include_router(
-    fastapi_users.get_reset_password_router(),
-    prefix="/auth",
-    tags=["auth"],
-    dependencies=[Depends(rate_limit_password_reset)],
-)
-app.include_router(
-    fastapi_users.get_verify_router(UserRead),
-    prefix="/auth",
-    tags=["auth"],
-)
+# Password / email-based auth routers are only mounted when not running in
+# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
+# POST /auth/register reachable, which is the bypass that allowed bots to
+# create non-OAuth users in spite of AUTH_TYPE=GOOGLE.
+if config.AUTH_TYPE != "GOOGLE":
+    app.include_router(
+        fastapi_users.get_auth_router(auth_backend),
+        prefix="/auth/jwt",
+        tags=["auth"],
+        dependencies=[
+            Depends(rate_limit_login),
+            Depends(
+                registration_allowed
+            ),  # honour REGISTRATION_ENABLED kill switch on login too
+        ],
+    )
+    app.include_router(
+        fastapi_users.get_register_router(UserRead, UserCreate),
+        prefix="/auth",
+        tags=["auth"],
+        dependencies=[
+            Depends(rate_limit_register),
+            Depends(registration_allowed),
+        ],
+    )
+    app.include_router(
+        fastapi_users.get_reset_password_router(),
+        prefix="/auth",
+        tags=["auth"],
+        dependencies=[Depends(rate_limit_password_reset)],
+    )
+    app.include_router(
+        fastapi_users.get_verify_router(UserRead),
+        prefix="/auth",
+        tags=["auth"],
+    )
+
+# /users/me (read/update profile) is needed in every auth mode, so it stays
+# mounted unconditionally.
 app.include_router(
     fastapi_users.get_users_router(UserRead, UserUpdate),
     prefix="/users",
||||||
),
|
),
|
||||||
prefix="/auth/google",
|
prefix="/auth/google",
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
dependencies=[
|
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
|
||||||
Depends(registration_allowed)
|
# it blocks BOTH new OAuth signups AND login of existing OAuth users
|
||||||
], # blocks OAuth registration when disabled
|
# (the fastapi-users OAuth router shares one callback for create+login,
|
||||||
|
# so this dependency closes both paths together).
|
||||||
|
dependencies=[Depends(registration_allowed)],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
||||||
# This endpoint performs a server-side redirect instead of returning JSON
|
# This endpoint performs a server-side redirect instead of returning JSON
|
||||||
# which fixes cross-site cookie issues where browsers don't send cookies
|
# which fixes cross-site cookie issues where browsers don't send cookies
|
||||||
# set via cross-origin fetch requests on subsequent redirects
|
# set via cross-origin fetch requests on subsequent redirects.
|
||||||
@app.get("/auth/google/authorize-redirect", tags=["auth"])
|
# The registration_allowed dependency mirrors the OAuth router above so
|
||||||
|
# the kill switch fails fast here instead of bouncing users to Google
|
||||||
|
# only to 403 on the callback.
|
||||||
|
@app.get(
|
||||||
|
"/auth/google/authorize-redirect",
|
||||||
|
tags=["auth"],
|
||||||
|
dependencies=[Depends(registration_allowed)],
|
||||||
|
)
|
||||||
async def google_authorize_redirect(
|
async def google_authorize_redirect(
|
||||||
request: Request,
|
request: Request,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@@ -22,10 +22,12 @@ def init_worker(**kwargs):
         initialize_image_gen_router,
         initialize_llm_router,
         initialize_openrouter_integration,
+        initialize_pricing_registration,
         initialize_vision_llm_router,
     )
 
     initialize_openrouter_integration()
+    initialize_pricing_registration()
     initialize_llm_router()
     initialize_image_gen_router()
     initialize_vision_llm_router()
@@ -47,11 +47,37 @@ def load_global_llm_configs():
             data = yaml.safe_load(f)
         configs = data.get("global_llm_configs", [])
 
+        # Lazy import keeps the `app.config` -> `app.services` edge one-way
+        # and matches the `provider_api_base` pattern used elsewhere.
+        from app.services.provider_capabilities import derive_supports_image_input
+
         seen_slugs: dict[str, int] = {}
         for cfg in configs:
             cfg.setdefault("billing_tier", "free")
             cfg.setdefault("anonymous_enabled", False)
             cfg.setdefault("seo_enabled", False)
+            # Capability flag: explicit YAML override always wins. When the
+            # operator has not annotated the model, defer to LiteLLM's
+            # authoritative model map (`supports_vision`) which already
+            # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
+            # vision-capable. Unknown / unmapped models default-allow so
+            # we don't lock the user out of a freshly added third-party
+            # entry; the streaming-task safety net (driven by
+            # `is_known_text_only_chat_model`) is the only place a False
+            # actually blocks a request.
+            if "supports_image_input" not in cfg:
+                litellm_params = cfg.get("litellm_params") or {}
+                base_model = (
+                    litellm_params.get("base_model")
+                    if isinstance(litellm_params, dict)
+                    else None
+                )
+                cfg["supports_image_input"] = derive_supports_image_input(
+                    provider=cfg.get("provider"),
+                    model_name=cfg.get("model_name"),
+                    base_model=base_model,
+                    custom_provider=cfg.get("custom_provider"),
+                )
+
             if cfg.get("seo_enabled") and cfg.get("seo_slug"):
                 slug = cfg["seo_slug"]
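``derive_supports_image_input`` itself is defined outside this diff; the comment above only describes its precedence. A plausible sketch of that precedence, assuming LiteLLM's ``supports_vision`` helper is available (the function body is an illustration, not the project's implementation):

import litellm


def derive_supports_image_input_sketch(
    model_name: str | None,
    explicit: bool | None = None,
) -> bool:
    # 1) An explicit YAML annotation always wins.
    if explicit is not None:
        return explicit
    # 2) Otherwise ask LiteLLM's model map whether the model accepts images.
    if model_name:
        try:
            return bool(litellm.supports_vision(model=model_name))
        except Exception:
            pass
    # 3) Unknown / unmapped models default-allow; the streaming-task
    #    safety net is where a False actually blocks a request.
    return True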
@@ -63,6 +89,27 @@ def load_global_llm_configs():
             else:
                 seen_slugs[slug] = cfg.get("id", 0)
 
+        # Stamp Auto (Fastest) ranking metadata. YAML configs are always
+        # Tier A — operator-curated, locked first when premium-eligible.
+        # The OpenRouter refresh tick later re-stamps health for any cfg
+        # whose provider == "OPENROUTER" via _enrich_health.
+        try:
+            from app.services.quality_score import static_score_yaml
+
+            for cfg in configs:
+                cfg["auto_pin_tier"] = "A"
+                static_q = static_score_yaml(cfg)
+                cfg["quality_score_static"] = static_q
+                cfg["quality_score"] = static_q
+                cfg["quality_score_health"] = None
+                # YAML cfgs whose provider is OPENROUTER are also subject
+                # to health gating against their own /endpoints data — a
+                # hand-picked dead OR model is still dead. _enrich_health
+                # re-stamps health_gated for them on the next refresh tick.
+                cfg["health_gated"] = False
+        except Exception as e:
+            print(f"Warning: Failed to score global LLM configs: {e}")
+
         return configs
     except Exception as e:
         print(f"Warning: Failed to load global LLM configs: {e}")
@@ -117,7 +164,11 @@ def load_global_image_gen_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-        return data.get("global_image_generation_configs", [])
+        configs = data.get("global_image_generation_configs", []) or []
+        for cfg in configs:
+            if isinstance(cfg, dict):
+                cfg.setdefault("billing_tier", "free")
+        return configs
     except Exception as e:
         print(f"Warning: Failed to load global image generation configs: {e}")
         return []
@@ -132,7 +183,11 @@ def load_global_vision_llm_configs():
     try:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
-        return data.get("global_vision_llm_configs", [])
+        configs = data.get("global_vision_llm_configs", []) or []
+        for cfg in configs:
+            if isinstance(cfg, dict):
+                cfg.setdefault("billing_tier", "free")
+        return configs
     except Exception as e:
         print(f"Warning: Failed to load global vision LLM configs: {e}")
         return []
@@ -194,6 +249,9 @@ def load_openrouter_integration_settings() -> dict | None:
     """
     Load OpenRouter integration settings from the YAML config.
 
+    Emits startup warnings for deprecated keys (``billing_tier``,
+    ``anonymous_enabled``) and seeds their replacements for back-compat.
+
     Returns:
         dict with settings if present and enabled, None otherwise
     """
@@ -206,9 +264,40 @@ def load_openrouter_integration_settings() -> dict | None:
         with open(global_config_file, encoding="utf-8") as f:
             data = yaml.safe_load(f)
         settings = data.get("openrouter_integration")
-        if settings and settings.get("enabled"):
-            return settings
-        return None
+        if not settings or not settings.get("enabled"):
+            return None
+
+        if "billing_tier" in settings:
+            print(
+                "Warning: openrouter_integration.billing_tier is deprecated; "
+                "tier is now derived per model from OpenRouter data "
+                "(':free' suffix or zero pricing). Remove this key."
+            )
+
+        if "anonymous_enabled" in settings:
+            print(
+                "Warning: openrouter_integration.anonymous_enabled is "
+                "deprecated; use anonymous_enabled_paid and/or "
+                "anonymous_enabled_free instead. Both new flags have been "
+                "seeded from the legacy value for back-compat."
+            )
+            settings.setdefault(
+                "anonymous_enabled_paid", settings["anonymous_enabled"]
+            )
+            settings.setdefault(
+                "anonymous_enabled_free", settings["anonymous_enabled"]
+            )
+
+        # Image generation + vision LLM emission are opt-in (issue L).
+        # OpenRouter's catalogue contains hundreds of image / vision
+        # capable models; auto-injecting all of them into every
+        # deployment would explode the model selector and surprise
+        # operators upgrading from prior versions. Default to False so
+        # admins must explicitly turn them on.
+        settings.setdefault("image_generation_enabled", False)
+        settings.setdefault("vision_enabled", False)
+
+        return settings
     except Exception as e:
         print(f"Warning: Failed to load OpenRouter integration settings: {e}")
         return None
@@ -217,9 +306,14 @@
 def initialize_openrouter_integration():
     """
     If enabled, fetch all OpenRouter models and append them to
-    config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
-    Should be called BEFORE initialize_llm_router() so the router
-    correctly excludes premium models from Auto mode.
+    config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
+    is derived per-model from OpenRouter's API signals (``:free`` suffix or
+    zero pricing), so free OpenRouter models correctly skip premium quota.
+
+    Should be called BEFORE initialize_llm_router(). Dynamic entries are
+    tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
+    by title-gen / sub-agent flows) remains scoped to curated YAML configs,
+    while user-facing Auto-mode thread pinning still considers them.
     """
     settings = load_openrouter_integration_settings()
     if not settings:
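The per-model tier derivation mentioned in this docstring lives in the integration service, not in this hunk. A sketch of the two signals it names, operating on an assumed OpenRouter ``/models`` entry shape with ``id`` and ``pricing`` keys:

def derive_billing_tier_sketch(model: dict) -> str:
    """Assumed shape: an OpenRouter /models entry with 'id' and 'pricing'."""
    model_id = model.get("id", "")
    pricing = model.get("pricing", {}) or {}
    # Signal 1: OpenRouter marks free variants with a ':free' suffix.
    if model_id.endswith(":free"):
        return "free"
    # Signal 2: zero prompt and completion pricing also means free.
    try:
        if float(pricing.get("prompt", 0)) == 0 and float(pricing.get("completion", 0)) == 0:
            return "free"
    except (TypeError, ValueError):
        pass
    return "premium"


print(derive_billing_tier_sketch({"id": "meta-llama/llama-3-8b:free", "pricing": {}}))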
@@ -235,16 +329,70 @@ def initialize_openrouter_integration():
 
         if new_configs:
             config.GLOBAL_LLM_CONFIGS.extend(new_configs)
+            free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
+            premium_count = sum(
+                1 for c in new_configs if c.get("billing_tier") == "premium"
+            )
             print(
                 f"Info: OpenRouter integration added {len(new_configs)} models "
-                f"(billing_tier={settings.get('billing_tier', 'premium')})"
+                f"(free={free_count}, premium={premium_count})"
             )
         else:
             print("Info: OpenRouter integration enabled but no models fetched")
+
+        # Image generation + vision LLM emissions are opt-in (issue L).
+        # Both reuse the catalogue already cached by ``service.initialize``
+        # so we don't make additional network calls here.
+        if settings.get("image_generation_enabled"):
+            try:
+                image_configs = service.get_image_generation_configs()
+                if image_configs:
+                    config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(image_configs)} "
+                        f"image-generation models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
+
+        if settings.get("vision_enabled"):
+            try:
+                vision_configs = service.get_vision_llm_configs()
+                if vision_configs:
+                    config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
+                    print(
+                        f"Info: OpenRouter integration added {len(vision_configs)} "
+                        f"vision LLM models"
+                    )
+            except Exception as e:
+                print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
     except Exception as e:
         print(f"Warning: Failed to initialize OpenRouter integration: {e}")
 
 
+def initialize_pricing_registration():
+    """
+    Teach LiteLLM the per-token cost of every deployment in
+    ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models pulled
+    from the OpenRouter catalogue + any operator-declared YAML pricing).
+
+    Must run AFTER ``initialize_openrouter_integration()`` so the
+    OpenRouter catalogue is populated and BEFORE the first LLM call so
+    ``response_cost`` is available in ``TokenTrackingCallback``.
+
+    Failures are logged but never raised — startup must not be blocked
+    by a missing pricing entry; the worst-case is the model debits 0.
+    """
+    try:
+        from app.services.pricing_registration import (
+            register_pricing_from_global_configs,
+        )
+
+        register_pricing_from_global_configs()
+    except Exception as e:
+        print(f"Warning: Failed to register LiteLLM pricing: {e}")
+
+
 def initialize_llm_router():
     """
     Initialize the LLM Router service for Auto mode.
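``register_pricing_from_global_configs`` is only imported above; the mechanism it relies on is LiteLLM's model-cost registry. A minimal sketch with made-up prices showing how a custom deployment's per-token cost can be registered so cost shows up later in callbacks:

import litellm

# Hypothetical deployment name and prices (USD per token), for illustration only.
litellm.register_model(
    {
        "openrouter/some-vendor/some-model": {
            "input_cost_per_token": 0.0000005,
            "output_cost_per_token": 0.0000015,
            "litellm_provider": "openrouter",
            "mode": "chat",
        }
    }
)

# Once registered, cost helpers and callbacks can price usage for that model.
print(litellm.model_cost["openrouter/some-vendor/some-model"]["input_cost_per_token"])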
@@ -389,14 +537,54 @@ class Config:
         os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
     )
 
-    # Premium token quota settings
-    PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
+    # Premium credit (micro-USD) quota settings.
+    #
+    # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
+    # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
+    # still honoured for one release as fall-back values — the prior
+    # $1-per-1M-tokens Stripe price means every existing value maps 1:1
+    # to micros, so operators upgrading without changing their .env still
+    # get correct behaviour. A startup deprecation warning fires below if
+    # they're set.
+    PREMIUM_CREDIT_MICROS_LIMIT = int(
+        os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
+        or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
+    )
     STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
-    STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
+    STRIPE_CREDIT_MICROS_PER_UNIT = int(
+        os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
+        or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
+    )
     STRIPE_TOKEN_BUYING_ENABLED = (
         os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
     )
 
+    # Safety ceiling on the per-call premium reservation. ``stream_new_chat``
+    # estimates an upper-bound cost from ``litellm.get_model_info`` x the
+    # config's ``quota_reserve_tokens`` and clamps the result to this value
+    # so a misconfigured "$1000/M" model can't lock the user's whole balance
+    # on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
+    # reserve_tokens ≈ $0.36) with headroom.
+    QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
+
+    if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
+        "PREMIUM_CREDIT_MICROS_LIMIT"
+    ):
+        print(
+            "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
+            "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
+            "current Stripe price). The old key will be removed in a "
+            "future release."
+        )
+    if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
+        "STRIPE_CREDIT_MICROS_PER_UNIT"
+    ):
+        print(
+            "Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
+            "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
+            "The old key will be removed in a future release."
+        )
 
     # Anonymous / no-login mode settings
     NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
     MULTI_AGENT_CHAT_ENABLED = (
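
A rough sketch of the clamp the new comment describes, assuming a helper along the lines of ``estimate_call_reserve_micros`` (the function name and call shape are illustrative; ``litellm.get_model_info()`` is the real lookup).

    import litellm

    def estimate_reserve_micros(model: str, reserve_tokens: int, ceiling_micros: int) -> int:
        """Upper-bound hold (micro-USD) for one call, clamped to the safety ceiling."""
        info = litellm.get_model_info(model)  # includes per-token USD costs when known
        per_token_usd = max(
            info.get("input_cost_per_token") or 0.0,
            info.get("output_cost_per_token") or 0.0,
        )
        estimate = int(round(per_token_usd * reserve_tokens * 1_000_000))
        return min(estimate, ceiling_micros)

    # Opus-class pricing (~$75 per 1M output tokens) with a 4,000-token reserve:
    # 0.000075 * 4000 * 1_000_000 = 300_000 micros ($0.30), under the $1.00 ceiling.
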
@@ -412,6 +600,35 @@ class Config:
     # Default quota reserve tokens when not specified per-model
     QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
 
+    # Per-image reservation (in micro-USD) used by ``billable_call`` for the
+    # ``POST /image-generations`` endpoint when the global config does not
+    # override it. $0.05 covers realistic worst-cases for current OpenAI /
+    # OpenRouter image-gen pricing. Bypassed entirely for free configs.
+    QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
+    )
+
+    # Per-podcast reservation (in micro-USD). One agent LLM call generating
+    # a transcript, typically 5k-20k completion tokens. $0.20 covers a long
+    # premium-model run. Tune via env.
+    QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
+    )
+
+    # Per-video-presentation reservation (in micro-USD). Fan-out of N
+    # slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
+    # plus refine retries; can produce many premium completions. $1.00
+    # covers worst-case. Tune via env.
+    #
+    # NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
+    # 1_000_000. The override path in ``billable_call`` bypasses the
+    # per-call clamp in ``estimate_call_reserve_micros``, so this is the
+    # *actual* hold — raising it via env is fine but means a single video
+    # task can lock $1+ of credit.
+    QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
+        os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
+    )
+
     # Abuse prevention: concurrent stream cap and CAPTCHA
     ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
     ANON_CAPTCHA_REQUEST_THRESHOLD = int(
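
Illustration only, not the service's actual code: how a ``billable_call``-style wrapper might pick the hold for these non-chat tasks, with an explicit per-config override winning over the defaults (the task keys and import path are assumptions).

    from app.config import config  # assumed import path

    TASK_RESERVE_MICROS = {
        "image_generation": config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS,  # 50_000    -> $0.05
        "podcast": config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,         # 200_000   -> $0.20
        "video_presentation": config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,  # 1_000_000 -> $1.00
    }

    def reserve_for(task: str, override_micros: int | None = None) -> int:
        # A per-config override bypasses the defaults (and, per the NOTE above,
        # the per-call clamp), so it is the actual amount held for the task.
        return override_micros if override_micros is not None else TASK_RESERVE_MICROS[task]
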
@@ -19,6 +19,24 @@
 # Structure matches NewLLMConfig:
 # - Model configuration (provider, model_name, api_key, etc.)
 # - Prompt configuration (system_instructions, citations_enabled)
+#
+# COST-BASED PREMIUM CREDITS:
+# Each premium config bills the user's USD-credit balance based on the
+# actual provider cost reported by LiteLLM. For models LiteLLM already
+# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
+# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
+# or any model LiteLLM doesn't have in its built-in pricing table, declare
+# per-token costs inline so they bill correctly:
+#
+#   litellm_params:
+#     base_model: "my-custom-azure-deploy"
+#     # USD per token; e.g. 0.000003 == $3.00 per million input tokens
+#     input_cost_per_token: 0.000003
+#     output_cost_per_token: 0.000015
+#
+# OpenRouter dynamic models pull pricing automatically from OpenRouter's
+# API — no inline declaration needed. Models without resolvable pricing
+# debit $0 from the user's balance and log a WARNING.
 
 # Router Settings for Auto Mode
 # These settings control how the LiteLLM Router distributes requests across models
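
Worked example of how the inline prices above turn into a debit, using the numbers from the comment block (the helper itself is illustrative, not part of the codebase).

    def debit_micros(prompt_tokens: int, completion_tokens: int,
                     input_cost_per_token: float, output_cost_per_token: float) -> int:
        usd = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
        return int(round(usd * 1_000_000))  # 1_000_000 micros == $1.00

    # 1,200 prompt + 350 completion tokens at $3/M in and $15/M out:
    # 1200*0.000003 + 350*0.000015 = 0.0036 + 0.00525 = $0.00885 -> 8_850 micros
    assert debit_micros(1200, 350, 0.000003, 0.000015) == 8850
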
@@ -245,31 +263,64 @@ global_llm_configs:
 # =============================================================================
 # When enabled, dynamically fetches ALL available models from the OpenRouter API
 # and injects them as global configs. This gives premium users access to any model
-# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
+# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
+# while free-tier OpenRouter models show up with a green Free badge and do NOT
+# consume premium quota.
 # Models are fetched at startup and refreshed periodically in the background.
 # All calls go through LiteLLM with the openrouter/ prefix.
 openrouter_integration:
   enabled: false
   api_key: "sk-or-your-openrouter-api-key"
-  # billing_tier: "premium" or "free". Controls whether users need premium tokens.
-  billing_tier: "premium"
-  # anonymous_enabled: set true to also show OpenRouter models to no-login users
-  anonymous_enabled: false
+  # Tier is derived PER MODEL from OpenRouter's own API signals:
+  #   - id ends with ":free" -> billing_tier=free
+  #   - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
+  #   - otherwise -> billing_tier=premium
+  # No global billing_tier knob is honored; any legacy value emits a startup warning.
+
+  # Anonymous access is split by tier so operators can expose only free
+  # models to no-login users without leaking paid inference.
+  anonymous_enabled_paid: false
+  anonymous_enabled_free: false
+
   seo_enabled: false
   # quota_reserve_tokens: tokens reserved per call for quota enforcement
   quota_reserve_tokens: 4000
-  # id_offset: starting negative ID for dynamically generated configs.
-  # Must not overlap with your static global_llm_configs IDs above.
+  # id_offset: base negative ID for dynamically generated configs.
+  # Model IDs are derived deterministically via BLAKE2b so they survive
+  # catalogue churn. Must not overlap with your static global_llm_configs IDs.
   id_offset: -10000
   # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
   refresh_interval_hours: 24
-  # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
-  # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
-  # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
-  # These values only matter if you set billing_tier to "free" (adding them to Auto mode).
-  # For premium-only models they are cosmetic. Set conservatively or match your account tier.
+  # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
+  # for per-deployment accounting when OR premium models participate in the
+  # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
+  # real account limits live at https://openrouter.ai/settings/limits.
   rpm: 200
   tpm: 1000000
+
+  # Rate limits for FREE OpenRouter models. Informational only: free OR
+  # models are intentionally kept OUT of the LiteLLM Router pool, because
+  # OpenRouter enforces free-tier limits globally per account (~20 RPM +
+  # 50-1000 daily requests across every ":free" model combined) —
+  # per-deployment router accounting can't represent a shared bucket
+  # correctly. Free OR models stay fully available in the model selector
+  # and for user-facing Auto thread pinning.
+  free_rpm: 20
+  free_tpm: 100000
+
+  # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
+  # contains hundreds of image- and vision-capable models; turning these on
+  # injects them into the global Image-Generation / Vision-LLM model
+  # selectors alongside any static configs. Tier (free/premium) is derived
+  # per model the same way it is for chat (`:free` suffix or zero pricing).
+  # When a user picks a premium image/vision model the call debits the
+  # shared $5 USD-cost-based premium credit pool — so leaving these off
+  # avoids surprise quota burn on existing deployments. Default: false.
+  image_generation_enabled: false
+  vision_enabled: false
+
   litellm_params:
     max_tokens: 16384
   system_instructions: ""
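
A sketch of the per-model tier rule described above, applied to one entry of OpenRouter's GET /api/v1/models response (the ``id`` and string-valued ``pricing`` fields match that API; the function itself is illustrative).

    def derive_billing_tier(model: dict) -> str:
        # Rule 1: OpenRouter marks free variants with an explicit ":free" id suffix.
        if model["id"].endswith(":free"):
            return "free"
        # Rule 2: zero prompt AND completion pricing also means free.
        pricing = model.get("pricing", {})
        if pricing.get("prompt") == "0" and pricing.get("completion") == "0":
            return "free"
        # Everything else bills the premium credit pool.
        return "premium"

    # e.g. {"id": "meta-llama/llama-3.3-70b-instruct:free", ...} -> "free"
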
@@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
         default=False,
         server_default="false",
     )
+    # Auto (Fastest) model pin for this thread: concrete resolved global LLM
+    # config id. NULL means no pin; Auto will resolve on the next turn.
+    # Single-writer invariant: only app.services.auto_model_pin_service sets
+    # or clears this column (plus bulk clears when a search space's
+    # agent_llm_id changes). Unindexed: all reads are by primary key.
+    pinned_llm_config_id = Column(Integer, nullable=True)
 
     # Relationships
     search_space = relationship("SearchSpace", back_populates="new_chat_threads")
@@ -669,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
 
     __tablename__ = "new_chat_messages"
 
+    # Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
+    # Mirrors alembic migration 141. Lets the streaming agent and the
+    # legacy frontend appendMessage call coexist idempotently — the second
+    # writer trips the unique and recovers without creating a duplicate row.
+    # Partial so legacy NULL turn_id rows and clone/snapshot inserts in
+    # app/services/public_chat_service.py (which omit turn_id) are unaffected.
+    __table_args__ = (
+        Index(
+            "uq_new_chat_messages_thread_turn_role",
+            "thread_id",
+            "turn_id",
+            "role",
+            unique=True,
+            postgresql_where=text("turn_id IS NOT NULL"),
+        ),
+    )
+
     role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
     # Content stored as JSONB to support rich content (text, tool calls, etc.)
     content = Column(JSONB, nullable=False)
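
The idempotent write this index enables looks roughly like the following SQLAlchemy sketch; the session handling and surrounding service code are assumptions, but ``on_conflict_do_nothing`` with ``index_elements`` / ``index_where`` is the standard PostgreSQL-dialect API.

    from sqlalchemy.dialects.postgresql import insert

    async def write_turn_message(session, thread_id: int, turn_id: str, role, content: dict) -> None:
        stmt = (
            insert(NewChatMessage)
            .values(thread_id=thread_id, turn_id=turn_id, role=role, content=content)
            # Targets the partial unique index above, so a second writer for the
            # same (thread, turn, role) becomes a no-op instead of a duplicate row.
            .on_conflict_do_nothing(
                index_elements=["thread_id", "turn_id", "role"],
                index_where=NewChatMessage.turn_id.isnot(None),
            )
        )
        await session.execute(stmt)
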
@@ -722,9 +745,26 @@ class TokenUsage(BaseModel, TimestampMixin):
 
     __tablename__ = "token_usage"
 
+    # Partial unique index on (message_id) where message_id IS NOT NULL.
+    # Mirrors alembic migration 142. Lets the streaming agent's
+    # ``finalize_assistant_turn`` and the legacy frontend ``append_message``
+    # recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
+    # racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
+    # (indexing, image generation, podcasts) — which keep ``message_id`` NULL
+    # because there is no per-message anchor — are unaffected.
+    __table_args__ = (
+        Index(
+            "uq_token_usage_message_id",
+            "message_id",
+            unique=True,
+            postgresql_where=text("message_id IS NOT NULL"),
+        ),
+    )
+
     prompt_tokens = Column(Integer, nullable=False, default=0)
     completion_tokens = Column(Integer, nullable=False, default=0)
     total_tokens = Column(Integer, nullable=False, default=0)
+    cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
     model_breakdown = Column(JSONB, nullable=True)
     call_details = Column(JSONB, nullable=True)
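
``cost_micros`` holds LiteLLM's per-call USD cost converted to integer micro-USD; a minimal conversion looks like this (the surrounding callback and row-writing code are assumed).

    def usd_to_micros(response_cost: float | None) -> int:
        # LiteLLM reports response_cost as a float in USD; unresolved pricing -> None -> 0.
        return int(round((response_cost or 0.0) * 1_000_000))

    # $0.004321 -> 4_321 micros; a model with no pricing entry debits 0.
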
@@ -1787,7 +1827,15 @@ class PagePurchase(Base, TimestampMixin):
 
 
 class PremiumTokenPurchase(Base, TimestampMixin):
-    """Tracks Stripe checkout sessions used to grant additional premium token credits."""
+    """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
+
+    Note: the table name is preserved (``premium_token_purchases``) for
+    operational continuity even though the unit is now USD micro-credits
+    instead of raw tokens. The ``credit_micros_granted`` column replaced
+    the legacy ``tokens_granted`` in migration 140; the stored values
+    were not transformed because the prior $1 = 1M tokens Stripe price
+    makes the unit conversion 1:1 numerically.
+    """
 
     __tablename__ = "premium_token_purchases"
     __allow_unmapped__ = True
@@ -1804,7 +1852,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
     )
     stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
     quantity = Column(Integer, nullable=False)
-    tokens_granted = Column(BigInteger, nullable=False)
+    credit_micros_granted = Column(BigInteger, nullable=False)
     amount_total = Column(Integer, nullable=True)
     currency = Column(String(10), nullable=True)
     status = Column(
@@ -2103,16 +2151,16 @@ if config.AUTH_TYPE == "GOOGLE":
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")
 
-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
             BigInteger,
             nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
@@ -2235,16 +2283,16 @@ else:
         )
         pages_used = Column(Integer, nullable=False, default=0, server_default="0")
 
-        premium_tokens_limit = Column(
+        premium_credit_micros_limit = Column(
            BigInteger,
             nullable=False,
-            default=config.PREMIUM_TOKEN_LIMIT,
-            server_default=str(config.PREMIUM_TOKEN_LIMIT),
+            default=config.PREMIUM_CREDIT_MICROS_LIMIT,
+            server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
         )
-        premium_tokens_used = Column(
+        premium_credit_micros_used = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
-        premium_tokens_reserved = Column(
+        premium_credit_micros_reserved = Column(
             BigInteger, nullable=False, default=0, server_default="0"
         )
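
Spendable credit is what remains of the limit after settled usage and still-open reservations; a small sketch of the check (not the actual quota service).

    def available_credit_micros(user) -> int:
        return max(
            0,
            user.premium_credit_micros_limit
            - user.premium_credit_micros_used
            - user.premium_credit_micros_reserved,
        )

    def can_reserve(user, reserve_micros: int) -> bool:
        # A call is admitted only if its hold fits in what is left after prior
        # usage and outstanding reservations.
        return reserve_micros <= available_credit_micros(user)
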
@@ -68,12 +68,25 @@ class EtlPipelineService:
                     etl_service="VISION_LLM",
                     content_type="image",
                 )
-            except Exception:
-                logging.warning(
-                    "Vision LLM failed for %s, falling back to document parser",
-                    request.filename,
-                    exc_info=True,
-                )
+            except Exception as exc:
+                # Special-case quota exhaustion so we log a clearer message
+                # — the vision LLM didn't "fail", the user just ran out of
+                # premium credit. Falling through to the document parser
+                # is a graceful degradation: OCR/Unstructured still
+                # extracts text from the image without burning credit.
+                from app.services.billable_calls import QuotaInsufficientError
+
+                if isinstance(exc, QuotaInsufficientError):
+                    logging.info(
+                        "Vision LLM quota exhausted for %s; falling back to document parser",
+                        request.filename,
+                    )
+                else:
+                    logging.warning(
+                        "Vision LLM failed for %s, falling back to document parser",
+                        request.filename,
+                        exc_info=True,
+                    )
         else:
             logging.info(
                 "No vision LLM provided, falling back to document parser for %s",
Some files were not shown because too many files have changed in this diff.