mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 06:12:40 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction
This commit is contained in:
commit
b981b51ab1
176 changed files with 20407 additions and 6258 deletions
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.0.19
|
||||
0.0.20
|
||||
|
|
|
|||
|
|
@ -159,10 +159,13 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
|
|||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||
|
||||
# Premium token purchases ($1 per 1M tokens for premium-tier models)
|
||||
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
|
||||
# credit; premium turns debit the actual per-call provider cost
|
||||
# reported by LiteLLM, so cheap and expensive models bill proportionally)
|
||||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT instead (1:1 numerical mapping): STRIPE_TOKENS_PER_UNIT=1000000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
|
|
@ -305,6 +308,24 @@ STT_SERVICE=local/base
|
|||
# Advanced (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# New-chat agent feature flags
|
||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||
SURFSENSE_ENABLE_BUSY_MUTEX=true
|
||||
SURFSENSE_ENABLE_SKILLS=true
|
||||
SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=true
|
||||
SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=true
|
||||
SURFSENSE_ENABLE_ACTION_LOG=true
|
||||
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
||||
SURFSENSE_ENABLE_PERMISSION=true
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||
|
||||
# Periodic connector sync interval (default: 5m)
|
||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||
|
||||
|
|
@ -315,9 +336,24 @@ STT_SERVICE=local/base
|
|||
# Pages limit per user for ETL (default: unlimited)
|
||||
# PAGES_LIMIT=500
|
||||
|
||||
# Premium token quota per registered user (default: 5M)
|
||||
# Only applies to models with billing_tier=premium in global_llm_config.yaml
|
||||
# PREMIUM_TOKEN_LIMIT=5000000
|
||||
# Premium credit quota per registered user, in micro-USD (default: $5).
|
||||
# Premium turns are debited at the actual per-call provider cost reported
|
||||
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT instead (1:1 numerical mapping): PREMIUM_TOKEN_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
||||
# Per-image reservation for POST /image-generations, in micro-USD ($0.05 default).
|
||||
# QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||
|
||||
# Per-podcast reservation for the podcast Celery task, in micro-USD ($0.20 default).
|
||||
# QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||
|
||||
# Per-video-presentation reservation for the video Celery task, in micro-USD ($1.00 default).
|
||||
# Override path bypasses QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||
# QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||
|
||||
# No-login (anonymous) mode — public users can chat without an account
|
||||
# Set TRUE to enable /free pages and anonymous chat API
|
||||
|
|
|
|||
|
|
@ -54,11 +54,15 @@ STRIPE_PAGES_PER_UNIT=1000
|
|||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||
|
||||
# Premium token purchases via Stripe (for premium-tier model usage)
|
||||
# Set TRUE to allow users to buy premium token packs ($1 per 1M tokens)
|
||||
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||
# per-call provider cost reported by LiteLLM.
|
||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
STRIPE_TOKENS_PER_UNIT=1000000
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
|
||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
|
|
@ -184,9 +188,35 @@ VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
|||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
|
||||
# Premium token quota per registered user (default: 3,000,000)
|
||||
# Applies only to models with billing_tier=premium in global_llm_config.yaml
|
||||
PREMIUM_TOKEN_LIMIT=3000000
|
||||
# Premium credit quota per registered user, in micro-USD
|
||||
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||
# models bill proportionally. Applies only to models with
|
||||
# billing_tier=premium in global_llm_config.yaml.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
|
||||
# PREMIUM_TOKEN_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||
# stream_new_chat estimates an upper-bound cost from the model's
|
||||
# litellm-published per-token rates × the config's quota_reserve_tokens
|
||||
# and clamps to this value so a misconfigured model can't lock the
|
||||
# user's whole balance on one call. Default $1.00.
|
||||
QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
||||
# Per-image reservation (in micro-USD) for the POST /image-generations
|
||||
# endpoint. Bypassed for free configs. Default $0.05.
|
||||
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS=50000
|
||||
|
||||
# Per-podcast reservation (in micro-USD) used by the podcast Celery task.
|
||||
# Single envelope covers one transcript-generation LLM call. Default $0.20.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS=200000
|
||||
|
||||
# Per-video-presentation reservation (in micro-USD) used by the video
|
||||
# presentation Celery task. Covers worst-case fan-out of N slide-scene
|
||||
# generations + refines. Default $1.00. NOTE: tasks using the override
|
||||
# path bypass the QUOTA_MAX_RESERVE_MICROS clamp — raise with care.
|
||||
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS=1000000
|
||||
|
||||
# No-login (anonymous) mode — allows public users to chat without an account
|
||||
# Set TRUE to enable /free pages and anonymous chat API
|
||||
|
|
@ -294,3 +324,30 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||
# Comma-separated allowlist of plugin entry-point names
|
||||
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compiled-agent cache (Phase 1 + 2 perf optimization, default ON)
|
||||
# -----------------------------------------------------------------------------
|
||||
# When ON, the per-turn LangGraph + middleware compile result (~3-5s of CPU
|
||||
# on a cold turn) is reused across subsequent turns on the same thread,
|
||||
# collapsing it to a microsecond hash lookup. All connector tools acquire
|
||||
# their own short-lived DB session per call (Phase 2 refactor) so a cached
|
||||
# closure is safe to share across requests. Flip OFF only as a last-resort
|
||||
# rollback if you suspect cache-related staleness.
|
||||
# SURFSENSE_ENABLE_AGENT_CACHE=true
|
||||
|
||||
# Cache capacity (max number of compiled-agent entries kept in memory)
|
||||
# and TTL per entry (seconds). Working set is typically one entry per
|
||||
# active thread on this replica; tune up for very large deployments.
|
||||
# SURFSENSE_AGENT_CACHE_MAXSIZE=256
|
||||
# SURFSENSE_AGENT_CACHE_TTL_SECONDS=1800
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Connector discovery TTL cache (Phase 1.4 perf optimization)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Caches the per-search-space "available connectors" + "available document
|
||||
# types" lookups that ``create_surfsense_deep_agent`` hits on every turn.
|
||||
# ORM event listeners auto-invalidate on connector / document inserts,
|
||||
# updates and deletes — the TTL only bounds staleness for bulk-import
|
||||
# paths that bypass the ORM. Set to 0 to disable the cache.
|
||||
# SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30
|
||||
|
|
|
|||
|
|
@ -38,16 +38,26 @@ RUN pip install --upgrade certifi pip-system-certs
|
|||
COPY pyproject.toml .
|
||||
COPY uv.lock .
|
||||
|
||||
# Install PyTorch based on architecture
|
||||
RUN if [ "$(uname -m)" = "x86_64" ]; then \
|
||||
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121; \
|
||||
else \
|
||||
pip install --no-cache-dir torch torchvision torchaudio; \
|
||||
fi
|
||||
|
||||
# Install python dependencies
|
||||
# Install all Python dependencies from uv.lock for deterministic builds.
|
||||
#
|
||||
# `uv pip install -e .` re-resolves from pyproject.toml and ignores uv.lock,
|
||||
# which lets prod silently drift to newer upstream versions on every rebuild
|
||||
# (e.g. deepagents 0.4.x -> 0.5.x breaking the FilesystemMiddleware imports).
|
||||
# Exporting the lock to requirements.txt and feeding it to `uv pip install`
|
||||
# pins every transitive package to the exact version captured in uv.lock.
|
||||
#
|
||||
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
|
||||
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
|
||||
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
|
||||
# captured in uv.lock). Installing from cu121 first only wasted ~2GB of
|
||||
# downloads that the lock-based install immediately replaced. If a specific
|
||||
# CUDA version is needed (driver compatibility, etc.), wire it through
|
||||
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
|
||||
RUN pip install --no-cache-dir uv && \
|
||||
uv pip install --system --no-cache-dir -e .
|
||||
uv export --frozen --no-dev --no-hashes --no-emit-project \
|
||||
--format requirements-txt -o /tmp/requirements.txt && \
|
||||
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
|
||||
rm /tmp/requirements.txt
|
||||
|
||||
# Set SSL environment variables dynamically
|
||||
RUN CERTIFI_PATH=$(python -c "import certifi; print(certifi.where())") && \
|
||||
|
|
@ -66,13 +76,18 @@ RUN cd /root/.EasyOCR/model && (unzip -o english_g2.zip || true) && (unzip -o cr
|
|||
# Pre-download Docling models
|
||||
# Best-effort warm-up: instantiate DocumentConverter once at build time.
# The previous one-liner embedded "\n" inside a double-quoted shell string;
# sh passes that through as a literal backslash + "n", so Python raised a
# SyntaxError that `|| true` silently swallowed and the warm-up never ran.
# A single-statement form needs no newlines, and `|| true` still keeps the
# build going if docling is unavailable.
RUN python -c "from docling.document_converter import DocumentConverter; DocumentConverter()" || true
|
||||
|
||||
# Install Playwright browsers for web scraping if needed
|
||||
RUN pip install playwright && \
|
||||
playwright install chromium --with-deps
|
||||
# Install Playwright browsers for web scraping (the playwright package itself
|
||||
# is already installed via uv.lock above)
|
||||
RUN playwright install chromium --with-deps
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Install the project itself in editable mode. Dependencies were already
|
||||
# installed deterministically from uv.lock above, so --no-deps prevents any
|
||||
# re-resolution that could pull newer versions.
|
||||
RUN uv pip install --system --no-cache-dir --no-deps -e .
|
||||
|
||||
# Copy and set permissions for entrypoint script
|
||||
# Use dos2unix to ensure LF line endings (fixes CRLF issues from Windows checkouts)
|
||||
COPY scripts/docker/entrypoint.sh /app/scripts/docker/entrypoint.sh
|
||||
|
|
|
|||
|
|
@ -0,0 +1,291 @@
|
|||
"""rename premium token columns to credit micros and add cost_micros to token_usage
|
||||
|
||||
Migrates the premium quota system from a flat token counter to a USD-cost
|
||||
based credit system, where 1 credit = 1 micro-USD ($0.000001).
|
||||
|
||||
Column renames (1:1 numerical mapping — the prior $1 per 1M tokens Stripe
|
||||
price means every existing value is already correct in the new unit, no
|
||||
data transformation needed):
|
||||
|
||||
user.premium_tokens_limit -> premium_credit_micros_limit
|
||||
user.premium_tokens_used -> premium_credit_micros_used
|
||||
user.premium_tokens_reserved -> premium_credit_micros_reserved
|
||||
|
||||
premium_token_purchases.tokens_granted -> credit_micros_granted
|
||||
|
||||
New column for cost auditing per turn:
|
||||
|
||||
token_usage.cost_micros (BigInteger NOT NULL DEFAULT 0)
|
||||
|
||||
The "user" table is in zero_publication's column list (added in 139), so
|
||||
this migration must drop and recreate the publication with the renamed
|
||||
column names, otherwise zero-cache will replicate stale column names and
|
||||
the FE Zero schema will fail to bind.
|
||||
|
||||
IMPORTANT - before AND after running this migration:
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Skipping the zero-cache stop will deadlock at the ACCESS EXCLUSIVE LOCK on
|
||||
"user". Skipping the data-volume reset will leave IndexedDB clients seeing
|
||||
column-not-found errors from a stale catalog snapshot.
|
||||
|
||||
Revision ID: 140
|
||||
Revises: 139
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# Alembic revision identifiers: this migration is "140" and applies on top of "139".
revision: str = "140"
down_revision: str | None = "139"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Publication that upgrade()/downgrade() drop and recreate around the
# column renames.
PUBLICATION_NAME = "zero_publication"

# Replicates 139's document column list verbatim — must stay in sync.
# Interpolated unquoted into the CREATE PUBLICATION statement.
DOCUMENT_COLS = [
    "id",
    "title",
    "document_type",
    "search_space_id",
    "folder_id",
    "created_by_id",
    "status",
    "created_at",
    "updated_at",
]

# Same five live-meter fields as 139, with the renamed column names.
# Note that premium_credit_micros_reserved is renamed by this migration but
# is not part of the published column list.
USER_COLS = [
    "id",
    "pages_limit",
    "pages_used",
    "premium_credit_micros_limit",
    "premium_credit_micros_used",
]
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
    """Kill backends whose locks on *table* would block our AccessExclusiveLock.

    Every backend other than the current one that has a ``pg_locks`` row for
    a relation named *table* in the current database is sent
    ``pg_terminate_backend``.

    Args:
        conn: Connection the migration is running on.
        table: Unqualified relation name (matched against ``pg_class.relname``).
    """
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            # pg_locks is cluster-wide while pg_class is per-database, so a
            # lock on a relation in ANOTHER database whose OID happens to
            # equal this table's OID would otherwise match the join and get
            # an unrelated backend killed.
            "  AND l.database = ("
            "SELECT d.oid FROM pg_database d "
            "WHERE d.datname = current_database()) "
            "  AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )
|
||||
|
||||
|
||||
def _has_zero_version(conn, table: str) -> bool:
    """Return True when *table* has a ``_0_version`` column.

    Delegates to :func:`_column_exists` so the ``information_schema`` lookup
    lives in one place instead of being duplicated with a hard-coded column
    name.

    Args:
        conn: Connection the migration is running on.
        table: Unqualified table name.
    """
    return _column_exists(conn, table, "_0_version")
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
    """Report whether ``information_schema.columns`` lists *column* on *table*.

    The lookup matches on table and column name only (no schema filter).

    Args:
        conn: Connection the migration is running on.
        table: Unqualified table name.
        column: Column name to look for.

    Returns:
        True if at least one matching row is found, False otherwise.
    """
    lookup = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = :col"
    )
    first_row = conn.execute(lookup, {"tbl": table, "col": column}).fetchone()
    return first_row is not None
|
||||
|
||||
|
||||
def _build_publication_ddl(
    user_cols: list[str],
    *,
    documents_has_zero_ver: bool,
    user_has_zero_ver: bool,
) -> str:
    """Render the ``CREATE PUBLICATION`` statement for ``PUBLICATION_NAME``.

    ``documents`` and ``"user"`` are published with explicit column lists;
    the remaining tables are published whole.

    Args:
        user_cols: Columns of ``"user"`` to publish.
        documents_has_zero_ver: Append ``"_0_version"`` to the documents
            column list when True.
        user_has_zero_ver: Append ``"_0_version"`` to the user column list
            when True.

    Returns:
        The DDL statement as a single string.
    """
    zero_version_col = '"_0_version"'

    document_columns = [*DOCUMENT_COLS]
    if documents_has_zero_ver:
        document_columns.append(zero_version_col)

    user_columns = [*user_cols]
    if user_has_zero_ver:
        user_columns.append(zero_version_col)

    documents_clause = "documents (" + ", ".join(document_columns) + ")"
    user_clause = '"user" (' + ", ".join(user_columns) + ")"

    # Order matters only for reproducing the exact statement text.
    published = [
        "notifications",
        documents_clause,
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
        user_clause,
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(published)
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Move the premium quota schema from token counts to credit micros.

    In order:

    1. Add ``token_usage.cost_micros`` (skipped if already present).
    2. Rename ``premium_token_purchases.tokens_granted`` to
       ``credit_micros_granted`` (skipped if already renamed).
    3. Under an ACCESS EXCLUSIVE lock on ``"user"``: drop the publication,
       rename the three ``premium_tokens_*`` columns, set the new limit
       column's server default, and recreate the publication with the new
       column names.
    """
    bind = op.get_bind()

    # ------------------------------------------------------------------
    # 1. token_usage.cost_micros — guarded so re-running is harmless.
    # ------------------------------------------------------------------
    if not _column_exists(bind, "token_usage", "cost_micros"):
        cost_micros_column = sa.Column(
            "cost_micros",
            sa.BigInteger(),
            nullable=False,
            server_default="0",
        )
        op.add_column("token_usage", cost_micros_column)

    # ------------------------------------------------------------------
    # 2. premium_token_purchases.tokens_granted -> credit_micros_granted.
    # ------------------------------------------------------------------
    if _column_exists(bind, "premium_token_purchases", "tokens_granted"):
        if not _column_exists(
            bind, "premium_token_purchases", "credit_micros_granted"
        ):
            op.alter_column(
                "premium_token_purchases",
                "tokens_granted",
                new_column_name="credit_micros_granted",
            )

    # ------------------------------------------------------------------
    # 3. user.premium_tokens_* -> premium_credit_micros_*.
    #
    # The publication references the old column names, so it is dropped
    # before the renames and rebuilt afterwards. When alembic already has
    # a transaction open a SAVEPOINT is used so the LOCK + DDL stay atomic.
    # ------------------------------------------------------------------
    if bind.in_transaction():
        scope = bind.begin_nested()
    else:
        scope = bind.begin()

    user_column_renames = (
        ("premium_tokens_limit", "premium_credit_micros_limit"),
        ("premium_tokens_used", "premium_credit_micros_used"),
        ("premium_tokens_reserved", "premium_credit_micros_reserved"),
    )

    with scope:
        bind.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(bind, "user")
        bind.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Re-assert REPLICA IDENTITY DEFAULT for safety; the primary key is
        # in the column list for both the old and the new shape.
        bind.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        # The publication must go before the columns it references change.
        bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for legacy_name, credit_name in user_column_renames:
            if not _column_exists(bind, "user", legacy_name):
                continue
            if _column_exists(bind, "user", credit_name):
                continue
            op.alter_column("user", legacy_name, new_column_name=credit_name)

        # New users default to 5_000_000 micros ($5); existing rows keep
        # whatever value they already have.
        op.alter_column(
            "user",
            "premium_credit_micros_limit",
            server_default="5000000",
        )

        # Rebuild the publication against the renamed columns.
        documents_flag = _has_zero_version(bind, "documents")
        user_flag = _has_zero_version(bind, "user")
        publication_ddl = _build_publication_ddl(
            USER_COLS,
            documents_has_zero_ver=documents_flag,
            user_has_zero_ver=user_flag,
        )
        bind.execute(sa.text(publication_ddl))
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo the credit-micros rename and remove ``cost_micros``.

    Under an ACCESS EXCLUSIVE lock on ``"user"``: drop the publication,
    rename the three ``premium_credit_micros_*`` columns back, reset the
    limit column's server default, and recreate the publication with the
    legacy column list. Then rename
    ``premium_token_purchases.credit_micros_granted`` back and drop
    ``token_usage.cost_micros``.

    The zero-cache stop/reset runbook from the module docstring applies in
    reverse.
    """
    bind = op.get_bind()

    if bind.in_transaction():
        scope = bind.begin_nested()
    else:
        scope = bind.begin()

    user_column_reverts = (
        ("premium_credit_micros_limit", "premium_tokens_limit"),
        ("premium_credit_micros_used", "premium_tokens_used"),
        ("premium_credit_micros_reserved", "premium_tokens_reserved"),
    )

    with scope:
        bind.execute(sa.text("SET lock_timeout = '10s'"))

        _terminate_blocked_pids(bind, "user")
        bind.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        bind.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        for credit_name, legacy_name in user_column_reverts:
            if not _column_exists(bind, "user", credit_name):
                continue
            if _column_exists(bind, "user", legacy_name):
                continue
            op.alter_column("user", credit_name, new_column_name=legacy_name)

        op.alter_column(
            "user",
            "premium_tokens_limit",
            server_default="5000000",
        )

        # Publication column list as it was before this revision.
        legacy_user_cols = [
            "id",
            "pages_limit",
            "pages_used",
            "premium_tokens_limit",
            "premium_tokens_used",
        ]
        documents_flag = _has_zero_version(bind, "documents")
        user_flag = _has_zero_version(bind, "user")
        publication_ddl = _build_publication_ddl(
            legacy_user_cols,
            documents_has_zero_ver=documents_flag,
            user_has_zero_ver=user_flag,
        )
        bind.execute(sa.text(publication_ddl))

    # NOTE(review): the two steps below are assumed to run after the locked
    # section (mirroring upgrade(), where the matching steps run before it)
    # — confirm against the original file's indentation.
    if _column_exists(bind, "premium_token_purchases", "credit_micros_granted"):
        if not _column_exists(bind, "premium_token_purchases", "tokens_granted"):
            op.alter_column(
                "premium_token_purchases",
                "credit_micros_granted",
                new_column_name="tokens_granted",
            )

    if _column_exists(bind, "token_usage", "cost_micros"):
        op.drop_column("token_usage", "cost_micros")
|
||||
357
surfsense_backend/app/agents/new_chat/agent_cache.py
Normal file
357
surfsense_backend/app/agents/new_chat/agent_cache.py
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
"""TTL-LRU cache for compiled SurfSense deep agents.
|
||||
|
||||
Why this exists
|
||||
---------------
|
||||
|
||||
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
|
||||
turn:
|
||||
|
||||
1. Discover connectors & document types from Postgres (~50-200ms)
|
||||
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
|
||||
3. Compose the system prompt
|
||||
4. Construct ~15 middleware instances (CPU)
|
||||
5. Eagerly compile the general-purpose subagent
|
||||
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
|
||||
which builds a second LangGraph + Pydantic schemas — ~1.5-2s of pure
|
||||
CPU work)
|
||||
6. Compile the outer LangGraph
|
||||
|
||||
For a single thread, all six steps produce the SAME object on every turn
|
||||
unless the user has changed their LLM config, toggled a feature flag,
|
||||
added a connector, etc. The right answer is to compile ONCE per
|
||||
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
|
||||
every subsequent turn on the same thread.
|
||||
|
||||
Why a per-thread key (not a global pool)
|
||||
----------------------------------------
|
||||
|
||||
Most middleware in the SurfSense stack captures per-thread state in
|
||||
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
|
||||
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
|
||||
would silently leak state across users and threads. Keying the cache on
|
||||
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
|
||||
turns on the same thread without changing any middleware's behavior.
|
||||
|
||||
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
|
||||
(read via ``runtime.context``) so the cache can collapse to a single
|
||||
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
|
||||
then, per-thread keying is the only safe option.
|
||||
|
||||
Cache shape
|
||||
-----------
|
||||
|
||||
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s, 30
|
||||
minutes — matches a typical chat session). ``maxsize`` (default 256)
|
||||
caps memory; LRU evicts least-recently-used on overflow.
|
||||
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
|
||||
cold misses on the same key wait for the first build instead of
|
||||
building N times.
|
||||
* Process-local: this is an in-memory cache. Multi-replica deployments
|
||||
pay the build cost once per replica per key. That's fine; the working
|
||||
set per replica is small (one entry per active thread on that replica).
|
||||
|
||||
Telemetry
|
||||
---------
|
||||
|
||||
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
|
||||
|
||||
* ``hit`` — cache hit, microseconds-fast
|
||||
* ``miss`` — first build for this key, includes build duration
|
||||
* ``stale`` — entry was found but expired; rebuilt
|
||||
* ``evict`` — LRU eviction (size-limited)
|
||||
* ``size`` — current cache occupancy at lookup time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Standard module-level logger.
logger = logging.getLogger(__name__)
# Perf logger; the ``[agent_cache]`` telemetry lines in this module are
# emitted through it.
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API: signature helpers (cache key components)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def stable_hash(*parts: Any) -> str:
    """Return the SHA1 hex digest of the ``repr`` of every item in ``parts``.

    Each part is fed to the digest as its UTF-8 encoded ``repr`` followed by
    an ASCII unit separator (``0x1f``), so ``("ab",)`` and ``("a", "b")``
    fingerprint differently. SHA1 is used purely as a content fingerprint
    for cache keys, not as a security boundary.
    """
    unit_separator = b"\x1f"
    digest = hashlib.sha1(usedforsecurity=False)
    for part in parts:
        encoded = repr(part).encode("utf-8", errors="replace")
        digest.update(encoded + unit_separator)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def tools_signature(
    tools: list[Any] | tuple[Any, ...],
    *,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
) -> str:
    """Fingerprint the bound tools plus the connector / document-type lists.

    Each tool contributes a ``(name, description)`` pair; a tool without a
    ``name`` attribute falls back to its ``repr`` and a missing
    ``description`` to ``""``. The pairs and both lists are sorted before
    hashing, so input order does not affect the result.

    The connector and document-type lists are hashed separately from the
    tool pairs so the key still rotates when they change without the bound
    tool list changing (for example, re-adding a connector that gates a tool
    which was not being exposed anyway).

    Args:
        tools: The bound tool objects.
        available_connectors: Connector names for the search space; ``None``
            is treated as empty.
        available_document_types: Document types for the search space;
            ``None`` is treated as empty.

    Returns:
        Hex digest produced by :func:`stable_hash`.
    """
    descriptors = [
        (getattr(tool, "name", repr(tool)), getattr(tool, "description", ""))
        for tool in tools
    ]
    descriptors.sort()
    connector_names = sorted(available_connectors or [])
    document_types = sorted(available_document_types or [])
    return stable_hash(descriptors, connector_names, document_types)
|
||||
|
||||
|
||||
def flags_signature(flags: Any) -> str:
    """Fingerprint a resolved feature-flags object via its ``repr``.

    Intended for the frozen :class:`AgentFeatureFlags` dataclass, whose
    ``repr`` is deterministic; the fingerprint is only as stable as the
    ``repr`` of whatever is passed in.
    """
    flags_repr = repr(flags)
    return stable_hash(flags_repr)
|
||||
|
||||
|
||||
def system_prompt_hash(system_prompt: str) -> str:
    """Return the SHA1 hex digest of the UTF-8 encoded system prompt.

    Characters that cannot be encoded are replaced rather than raising.
    """
    prompt_bytes = system_prompt.encode("utf-8", errors="replace")
    digest = hashlib.sha1(usedforsecurity=False)
    digest.update(prompt_bytes)
    return digest.hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Entry:
|
||||
value: Any
|
||||
created_at: float
|
||||
last_used_at: float
|
||||
|
||||
|
||||
class _AgentCache:
|
||||
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
|
||||
|
||||
NOT THREAD-SAFE in the multithreading sense — designed for a single
|
||||
asyncio event loop. Uvicorn runs one event loop per worker process,
|
||||
so this is fine; multi-worker deployments simply each maintain their
|
||||
own cache.
|
||||
"""
|
||||
|
||||
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
|
||||
self._maxsize = maxsize
|
||||
self._ttl = ttl_seconds
|
||||
self._entries: OrderedDict[str, _Entry] = OrderedDict()
|
||||
# One lock per key — guards "build" so concurrent cold misses on
|
||||
# the same key wait for the first build instead of all racing.
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _now(self) -> float:
|
||||
return time.monotonic()
|
||||
|
||||
def _is_fresh(self, entry: _Entry) -> bool:
|
||||
return (self._now() - entry.created_at) < self._ttl
|
||||
|
||||
def _evict_if_full(self) -> None:
|
||||
while len(self._entries) >= self._maxsize:
|
||||
evicted_key, _ = self._entries.popitem(last=False)
|
||||
self._locks.pop(evicted_key, None)
|
||||
_perf_log.info(
|
||||
"[agent_cache] evict key=%s reason=lru size=%d",
|
||||
_short(evicted_key),
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
def _touch(self, key: str, entry: _Entry) -> None:
|
||||
entry.last_used_at = self._now()
|
||||
self._entries.move_to_end(key, last=True)
|
||||
|
||||
async def get_or_build(
    self,
    key: str,
    *,
    builder: Callable[[], Awaitable[Any]],
) -> Any:
    """Return the cached value for ``key``, building it on a miss.

    ``builder`` MUST be idempotent. Concurrent cold misses on the same key
    are serialized on a per-key lock: the first caller builds, the others
    find the populated entry once they acquire the lock.

    Args:
        key: Cache key.
        builder: Zero-argument coroutine function producing the value.

    Returns:
        The cached or freshly built value.

    Raises:
        BaseException: Whatever ``builder`` raises is logged and re-raised;
            nothing is cached for a failed build.
    """
    cached = self._entries.get(key)
    if cached is not None:
        if self._is_fresh(cached):
            # Fast path: hot hit.
            self._touch(key, cached)
            _perf_log.info(
                "[agent_cache] hit key=%s age=%.1fs size=%d",
                _short(key),
                self._now() - cached.created_at,
                len(self._entries),
            )
            return cached.value

        # Present but expired — drop it and fall through to a rebuild.
        _perf_log.info(
            "[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
            _short(key),
            self._now() - cached.created_at,
            self._ttl,
        )
        self._entries.pop(key, None)

    # Slow path: one build at a time per key.
    build_lock = self._locks.setdefault(key, asyncio.Lock())
    async with build_lock:
        # Re-check: another waiter may have populated the entry while this
        # caller was waiting for the lock.
        cached = self._entries.get(key)
        if cached is not None and self._is_fresh(cached):
            self._touch(key, cached)
            _perf_log.info(
                "[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
                _short(key),
                self._now() - cached.created_at,
                len(self._entries),
            )
            return cached.value

        started = time.perf_counter()
        try:
            built = await builder()
        except BaseException:
            # Failed builds are not cached, so the next caller retries.
            _perf_log.warning(
                "[agent_cache] build_failed key=%s elapsed=%.3fs",
                _short(key),
                time.perf_counter() - started,
            )
            raise
        build_seconds = time.perf_counter() - started

        # Make room first, then insert at the most-recent end.
        self._evict_if_full()
        inserted_at = self._now()
        self._entries[key] = _Entry(
            value=built,
            created_at=inserted_at,
            last_used_at=inserted_at,
        )
        self._entries.move_to_end(key, last=True)
        _perf_log.info(
            "[agent_cache] miss key=%s build=%.3fs size=%d",
            _short(key),
            build_seconds,
            len(self._entries),
        )
        return built
|
||||
|
||||
def invalidate(self, key: str) -> bool:
    """Remove the entry stored under ``key``, together with its build lock.

    Returns ``True`` when an entry was actually removed and ``False``
    when ``key`` was not cached. The lock is discarded in both cases.
    """
    dropped = self._entries.pop(key, None)
    self._locks.pop(key, None)
    if dropped is None:
        return False
    _perf_log.info(
        "[agent_cache] invalidate key=%s size=%d",
        _short(key),
        len(self._entries),
    )
    return True
|
||||
|
||||
def invalidate_prefix(self, prefix: str) -> int:
    """Remove every cached entry whose key begins with ``prefix``.

    The build lock of each removed key is discarded as well. Returns
    the number of entries removed; nothing is logged when that is zero.
    """
    # Collect the matches first so the mapping is not mutated while it
    # is being iterated.
    doomed = [cached_key for cached_key in self._entries if cached_key.startswith(prefix)]
    removed = len(doomed)
    for cached_key in doomed:
        self._entries.pop(cached_key, None)
        self._locks.pop(cached_key, None)
    if removed:
        _perf_log.info(
            "[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
            _short(prefix),
            removed,
            len(self._entries),
        )
    return removed
|
||||
|
||||
def clear(self) -> None:
    """Empty the cache: drop every entry and every per-key build lock.

    The number of removed entries is logged, unless the cache was
    already empty.
    """
    removed = len(self._entries)
    for store in (self._entries, self._locks):
        store.clear()
    if removed:
        _perf_log.info("[agent_cache] clear removed=%d", removed)
|
||||
|
||||
def stats(self) -> dict[str, Any]:
    """Return a snapshot of the cache's current size and configured limits.

    Keys: ``size`` (entries currently held), ``maxsize`` and
    ``ttl_seconds`` (the values this cache was configured with).
    """
    return dict(
        size=len(self._entries),
        maxsize=self._maxsize,
        ttl_seconds=self._ttl,
    )
|
||||
|
||||
|
||||
def _short(key: str, n: int = 16) -> str:
|
||||
"""Truncate keys for log lines so they don't blow up log volume."""
|
||||
return key if len(key) <= n else f"{key[:n]}..."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Capacity and TTL of the process-wide cache. Both can be overridden from
# the environment and are parsed once, at import time -- a value that is
# not a valid number raises ``ValueError`` during import.
_DEFAULT_MAXSIZE = int(os.environ.get("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.environ.get("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))

# The singleton handed out by ``get_cache()``.
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
|
||||
|
||||
|
||||
def get_cache() -> _AgentCache:
    """Return the compiled-agent cache shared by the whole process.

    The module-level binding is looked up on every call, so callers
    should fetch the cache through this function each time rather than
    keep their own reference to it.
    """
    return _cache
|
||||
|
||||
|
||||
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
    """Swap the module singleton for a brand-new, empty cache. Tests only.

    The replacement is built with the given ``maxsize`` and
    ``ttl_seconds`` and is what ``get_cache()`` returns afterwards.
    The new cache is also returned for convenience.
    """
    global _cache
    fresh = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
    _cache = fresh
    return fresh
|
||||
|
||||
|
||||
__all__ = [
|
||||
"flags_signature",
|
||||
"get_cache",
|
||||
"reload_for_tests",
|
||||
"stable_hash",
|
||||
"system_prompt_hash",
|
||||
"tools_signature",
|
||||
]
|
||||
|
|
@ -40,6 +40,13 @@ from langchain_core.tools import BaseTool
|
|||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
|
|
@ -53,6 +60,7 @@ from app.agents.new_chat.middleware import (
|
|||
DedupHITLToolCallsMiddleware,
|
||||
DoomLoopMiddleware,
|
||||
FileIntentMiddleware,
|
||||
FlattenSystemMessageMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
|
|
@ -330,23 +338,39 @@ async def create_surfsense_deep_agent(
|
|||
else None,
|
||||
)
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
# Discover available connectors and document types for this search space.
|
||||
#
|
||||
# NOTE: These two calls cannot be parallelized via ``asyncio.gather``.
|
||||
# ``ConnectorService`` shares a single ``AsyncSession`` (``self.session``);
|
||||
# SQLAlchemy explicitly forbids concurrent operations on the same session
|
||||
# ("This session is provisioning a new connection; concurrent operations
|
||||
# are not permitted on the same session"). The Phase 1.4 in-process TTL
|
||||
# cache in ``connector_service`` already collapses the warm path to a
|
||||
# near-zero pair of dict lookups, so sequential awaits cost nothing in
|
||||
# the common case while remaining correct on cold cache misses.
|
||||
available_connectors: list[str] | None = None
|
||||
available_document_types: list[str] | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
connector_types = await connector_service.get_available_connectors(
|
||||
search_space_id
|
||||
)
|
||||
if connector_types:
|
||||
available_connectors = _map_connectors_to_searchable_types(connector_types)
|
||||
try:
|
||||
connector_types_result = await connector_service.get_available_connectors(
|
||||
search_space_id
|
||||
)
|
||||
if connector_types_result:
|
||||
available_connectors = _map_connectors_to_searchable_types(
|
||||
connector_types_result
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to discover available connectors: %s", e)
|
||||
|
||||
available_document_types = await connector_service.get_available_document_types(
|
||||
search_space_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
available_document_types = (
|
||||
await connector_service.get_available_document_types(search_space_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to discover available document types: %s", e)
|
||||
except Exception as e: # pragma: no cover - defensive outer guard
|
||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
|
|
@ -469,29 +493,77 @@ async def create_surfsense_deep_agent(
|
|||
# entire middleware build + main-graph compile into a single
|
||||
# ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the
|
||||
# event loop stays responsive.
|
||||
#
|
||||
# PHASE 1: cache the resulting compiled graph. ``agent_cache`` is keyed
|
||||
# on every per-request value that any middleware in the stack closes
|
||||
# over in ``__init__`` — drop one and you risk leaking state across
|
||||
# threads. Hits collapse this whole block to a microsecond lookup;
|
||||
# misses pay the original CPU cost AND populate the cache.
|
||||
config_id = agent_config.config_id if agent_config is not None else None
|
||||
|
||||
async def _build_agent() -> Any:
|
||||
return await asyncio.to_thread(
|
||||
_build_compiled_agent_blocking,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
visibility=visibility,
|
||||
anon_session_id=anon_session_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
# ``mentioned_document_ids`` is consumed by
|
||||
# ``KnowledgePriorityMiddleware`` per turn via
|
||||
# ``runtime.context`` (Phase 1.5). We still pass the
|
||||
# caller-provided list here for the legacy fallback path
|
||||
# (cache disabled / context not propagated) — the middleware
|
||||
# drains its own copy after the first read so a cached graph
|
||||
# never replays stale mentions.
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
max_input_tokens=_max_input_tokens,
|
||||
flags=_flags,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
_build_compiled_agent_blocking,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_selection.mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
visibility=visibility,
|
||||
anon_session_id=anon_session_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
max_input_tokens=_max_input_tokens,
|
||||
flags=_flags,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack:
|
||||
# Cache key components — order matters only for human readability;
|
||||
# the resulting hash is what's stored. Every component must
|
||||
# rotate on a real shape change AND stay stable across identical
|
||||
# invocations.
|
||||
cache_key = stable_hash(
|
||||
"v1", # schema version of the key — bump if components change
|
||||
config_id,
|
||||
thread_id,
|
||||
user_id,
|
||||
search_space_id,
|
||||
visibility,
|
||||
filesystem_selection.mode,
|
||||
anon_session_id,
|
||||
tools_signature(
|
||||
tools,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
),
|
||||
flags_signature(_flags),
|
||||
system_prompt_hash(final_system_prompt),
|
||||
_max_input_tokens,
|
||||
# ``mentioned_document_ids`` deliberately omitted — middleware
|
||||
# reads it from ``runtime.context`` (Phase 1.5).
|
||||
)
|
||||
agent = await get_cache().get_or_build(cache_key, builder=_build_agent)
|
||||
else:
|
||||
agent = await _build_agent()
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs (cache=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
"on"
|
||||
if _flags.enable_agent_cache and not _flags.disable_new_agent_stack
|
||||
else "off",
|
||||
)
|
||||
|
||||
_perf_log.info(
|
||||
|
|
@ -1038,6 +1110,14 @@ def _build_compiled_agent_blocking(
|
|||
noop_mw,
|
||||
retry_mw,
|
||||
fallback_mw,
|
||||
# Coalesce a multi-text-block system message into one block
|
||||
# immediately before the model call. Sits innermost on the
|
||||
# system-message-mutation chain so it observes every appender
|
||||
# (todo / filesystem / skills / subagents …) and prevents
|
||||
# OpenRouter→Anthropic from redistributing ``cache_control``
|
||||
# across N blocks and tripping Anthropic's 4-breakpoint cap.
|
||||
# See ``middleware/flatten_system.py`` for full rationale.
|
||||
FlattenSystemMessageMiddleware(),
|
||||
# Tool-call repair must run after model emits but before
|
||||
# permission / dedup / doom-loop interpret the calls.
|
||||
repair_mw,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,25 @@
|
|||
"""
|
||||
Context schema definitions for SurfSense agents.
|
||||
|
||||
This module defines the custom state schema used by the SurfSense deep agent.
|
||||
This module defines the per-invocation context object passed to the SurfSense
|
||||
deep agent via ``agent.astream_events(..., context=ctx)`` (LangGraph >= 0.6).
|
||||
|
||||
The agent's compiled graph is the same across invocations (and cached by
|
||||
``agent_cache``), so anything that varies per turn — the user mentions a
|
||||
specific document, the front-end issues a unique ``request_id``, etc. —
|
||||
MUST live on this context object instead of being captured into a
|
||||
middleware ``__init__`` closure. Middlewares read fields back via
|
||||
``runtime.context.<field>``; tools read them via ``runtime.context``.
|
||||
|
||||
This object is read inside both ``KnowledgePriorityMiddleware`` (for
|
||||
``mentioned_document_ids``) and any future middleware that needs
|
||||
per-request state without invalidating the compiled-agent cache.
|
||||
"""
|
||||
|
||||
from typing import NotRequired, TypedDict
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class FileOperationContractState(TypedDict):
|
||||
|
|
@ -15,25 +30,35 @@ class FileOperationContractState(TypedDict):
|
|||
turn_id: str
|
||||
|
||||
|
||||
class SurfSenseContextSchema(TypedDict):
|
||||
@dataclass
|
||||
class SurfSenseContextSchema:
|
||||
"""
|
||||
Custom state schema for the SurfSense deep agent.
|
||||
Per-invocation context for the SurfSense deep agent.
|
||||
|
||||
This extends the default agent state with custom fields.
|
||||
The default state already includes:
|
||||
- messages: Conversation history
|
||||
- todos: Task list from TodoListMiddleware
|
||||
- files: Virtual filesystem from FilesystemMiddleware
|
||||
Defaults are chosen so the dataclass can be safely default-constructed
|
||||
(LangGraph's ``Runtime.context`` itself defaults to ``None`` if no
|
||||
context is supplied — see ``langgraph.runtime.Runtime``). All fields
|
||||
are optional; consumers must None-check before reading.
|
||||
|
||||
We're adding fields needed for knowledge base search:
|
||||
- search_space_id: The user's search space ID
|
||||
- db_session: Database session (injected at runtime)
|
||||
- connector_service: Connector service instance (injected at runtime)
|
||||
Phase 1.5 fields:
|
||||
search_space_id: Search space the request is scoped to.
|
||||
mentioned_document_ids: KB documents the user @-mentioned this turn.
|
||||
Read by ``KnowledgePriorityMiddleware`` to seed its priority
|
||||
list. Stays out of the compiled-agent cache key — that's the
|
||||
whole point of putting it here.
|
||||
file_operation_contract: One-shot file operation contract emitted
|
||||
by ``FileIntentMiddleware`` for the upcoming turn.
|
||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||
task; populated for telemetry.
|
||||
|
||||
Phase 2 will extend with: thread_id, user_id, visibility,
|
||||
filesystem_mode, anon_session_id, available_connectors,
|
||||
available_document_types, created_by_id (everything currently captured
|
||||
by middleware ``__init__`` closures).
|
||||
"""
|
||||
|
||||
search_space_id: int
|
||||
file_operation_contract: NotRequired[FileOperationContractState]
|
||||
turn_id: NotRequired[str]
|
||||
request_id: NotRequired[str]
|
||||
# These are runtime-injected and won't be serialized
|
||||
# db_session and connector_service are passed when invoking the agent
|
||||
search_space_id: int | None = None
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
file_operation_contract: FileOperationContractState | None = None
|
||||
turn_id: str | None = None
|
||||
request_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ Feature flags for the SurfSense new_chat agent stack.
|
|||
|
||||
These flags gate the newer agent middleware (some ported from OpenCode,
|
||||
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
||||
SurfSense-native). They follow a "default-OFF for risky things,
|
||||
default-ON for safe upgrades, master kill-switch for everything new" model.
|
||||
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
|
||||
image updates work even when older installs do not have newly introduced
|
||||
environment variables. Risky/experimental integrations stay default OFF,
|
||||
and the master kill-switch can still disable everything new.
|
||||
|
||||
All new middleware checks its flag at agent build time. If the master
|
||||
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
||||
|
|
@ -14,16 +16,19 @@ operators a single switch to revert to pre-port behavior.
|
|||
Examples
|
||||
--------
|
||||
|
||||
Local development (recommended for trying everything except doom-loop / selector):
|
||||
Defaults:
|
||||
|
||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
|
||||
SURFSENSE_ENABLE_PERMISSION=true
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||
|
||||
Master kill-switch (overrides everything else):
|
||||
|
||||
|
|
@ -60,32 +65,28 @@ class AgentFeatureFlags:
|
|||
disable_new_agent_stack: bool = False
|
||||
|
||||
# Agent quality — context budget, retry/limits, name-repair, doom-loop
|
||||
enable_context_editing: bool = False
|
||||
enable_compaction_v2: bool = False
|
||||
enable_retry_after: bool = False
|
||||
enable_context_editing: bool = True
|
||||
enable_compaction_v2: bool = True
|
||||
enable_retry_after: bool = True
|
||||
enable_model_fallback: bool = False
|
||||
enable_model_call_limit: bool = False
|
||||
enable_tool_call_limit: bool = False
|
||||
enable_tool_call_repair: bool = False
|
||||
enable_doom_loop: bool = (
|
||||
False # Default OFF until UI handles permission='doom_loop'
|
||||
)
|
||||
enable_model_call_limit: bool = True
|
||||
enable_tool_call_limit: bool = True
|
||||
enable_tool_call_repair: bool = True
|
||||
enable_doom_loop: bool = True
|
||||
|
||||
# Safety — permissions, concurrency, tool-set narrowing
|
||||
enable_permission: bool = False # Default OFF for first deploy
|
||||
enable_busy_mutex: bool = False
|
||||
enable_permission: bool = True
|
||||
enable_busy_mutex: bool = True
|
||||
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
|
||||
|
||||
# Skills + subagents
|
||||
enable_skills: bool = False
|
||||
enable_specialized_subagents: bool = False
|
||||
enable_kb_planner_runnable: bool = False
|
||||
enable_skills: bool = True
|
||||
enable_specialized_subagents: bool = True
|
||||
enable_kb_planner_runnable: bool = True
|
||||
|
||||
# Snapshot / revert
|
||||
enable_action_log: bool = False
|
||||
enable_revert_route: bool = (
|
||||
False # Backend ships before UI; route returns 503 until this flips
|
||||
)
|
||||
enable_action_log: bool = True
|
||||
enable_revert_route: bool = True
|
||||
|
||||
# Streaming parity v2 — opt in to LangChain's structured
|
||||
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||
|
|
@ -94,7 +95,7 @@ class AgentFeatureFlags:
|
|||
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||
# ship unconditionally because they're forward-compatible.
|
||||
enable_stream_parity_v2: bool = False
|
||||
enable_stream_parity_v2: bool = True
|
||||
|
||||
# Plugins
|
||||
enable_plugin_loader: bool = False
|
||||
|
|
@ -102,6 +103,41 @@ class AgentFeatureFlags:
|
|||
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||
enable_otel: bool = False
|
||||
|
||||
# Performance — compiled-agent cache (Phase 1 + Phase 2).
|
||||
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
|
||||
# graph if the cache key matches (LLM config + thread + tool surface +
|
||||
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
|
||||
# wall clock from ~4-5s to <50µs on cache hits.
|
||||
#
|
||||
# SAFETY (Phase 2 unblocked this default-on):
|
||||
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
|
||||
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
|
||||
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
|
||||
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
|
||||
# ``update_memory``, ``search_surfsense_docs``) now acquire fresh
|
||||
# short-lived ``AsyncSession`` instances per call via
|
||||
# :data:`async_session_maker`. The factory still accepts ``db_session``
|
||||
# for registry compatibility but ``del``'s it immediately — see any
|
||||
# of those files' factory docstrings for the rationale. The ``llm``
|
||||
# closure is per-(provider, model, config_id) which is already in
|
||||
# the cache key, so the LLM is safe to share across cached hits of
|
||||
# the same key. The KB priority middleware reads
|
||||
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
|
||||
# not its constructor closure, so the same compiled agent serves
|
||||
# turns with different mention lists correctly.
|
||||
#
|
||||
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
|
||||
# environment if a regression surfaces. The path is exercised by
|
||||
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
|
||||
enable_agent_cache: bool = True
|
||||
# Phase 1 (deferred — measure first): pre-build & share the
|
||||
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
|
||||
# misses. Only helps when the outer cache MISSES (cache hits already
|
||||
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
|
||||
# until we have data showing cold misses are frequent enough to
|
||||
# justify the extra global state.
|
||||
enable_agent_cache_share_gp_subagent: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> AgentFeatureFlags:
|
||||
"""Read flags from environment.
|
||||
|
|
@ -115,48 +151,76 @@ class AgentFeatureFlags:
|
|||
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
||||
"middleware is forced OFF for this build."
|
||||
)
|
||||
return cls(disable_new_agent_stack=True)
|
||||
return cls(
|
||||
disable_new_agent_stack=True,
|
||||
enable_context_editing=False,
|
||||
enable_compaction_v2=False,
|
||||
enable_retry_after=False,
|
||||
enable_model_fallback=False,
|
||||
enable_model_call_limit=False,
|
||||
enable_tool_call_limit=False,
|
||||
enable_tool_call_repair=False,
|
||||
enable_doom_loop=False,
|
||||
enable_permission=False,
|
||||
enable_busy_mutex=False,
|
||||
enable_llm_tool_selector=False,
|
||||
enable_skills=False,
|
||||
enable_specialized_subagents=False,
|
||||
enable_kb_planner_runnable=False,
|
||||
enable_action_log=False,
|
||||
enable_revert_route=False,
|
||||
enable_stream_parity_v2=False,
|
||||
enable_plugin_loader=False,
|
||||
enable_otel=False,
|
||||
enable_agent_cache=False,
|
||||
enable_agent_cache_share_gp_subagent=False,
|
||||
)
|
||||
|
||||
return cls(
|
||||
disable_new_agent_stack=False,
|
||||
# Agent quality
|
||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
|
||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
|
||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
|
||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
|
||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||
enable_model_call_limit=_env_bool(
|
||||
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False
|
||||
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
|
||||
),
|
||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
|
||||
enable_tool_call_repair=_env_bool(
|
||||
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False
|
||||
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
|
||||
),
|
||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
|
||||
# Safety
|
||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
|
||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
|
||||
enable_llm_tool_selector=_env_bool(
|
||||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
||||
),
|
||||
# Skills + subagents
|
||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
|
||||
enable_specialized_subagents=_env_bool(
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
|
||||
),
|
||||
enable_kb_planner_runnable=_env_bool(
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
|
||||
),
|
||||
# Snapshot / revert
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||
# Streaming parity v2
|
||||
enable_stream_parity_v2=_env_bool(
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
|
||||
),
|
||||
# Plugins
|
||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||
# Observability
|
||||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||
# Performance
|
||||
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
|
||||
enable_agent_cache_share_gp_subagent=_env_bool(
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
|
||||
),
|
||||
)
|
||||
|
||||
def any_new_middleware_enabled(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
#
|
||||
# Single source of truth lives in
|
||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||
# runs during ``app.config`` class-body init) can resolve provider
|
||||
# prefixes without dragging the agent / tools tree into module load
|
||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||
# tests) keep working unchanged.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
|
|
@ -178,6 +155,17 @@ class AgentConfig:
|
|||
anonymous_enabled: bool = False
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
# Capability flag: best-effort True for the chat selector / catalog.
|
||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||
# is the conservative-allow stance — the streaming-task safety net
|
||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||
# actually blocks a request. Setting this to False here without an
|
||||
# authoritative source would silently hide vision-capable models
|
||||
# (the regression we're fixing).
|
||||
supports_image_input: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
|
|
@ -203,6 +191,12 @@ class AgentConfig:
|
|||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# contains at least one vision-capable deployment; the router
|
||||
# will surface a 404 from a non-vision deployment as a normal
|
||||
# ``allowed_fails`` event and fail over rather than blocking
|
||||
# the request outright.
|
||||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -216,10 +210,24 @@ class AgentConfig:
|
|||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
return cls(
|
||||
provider=config.provider.value
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider),
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
|
|
@ -235,6 +243,16 @@ class AgentConfig:
|
|||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no operator-curated capability flag, so we
|
||||
# ask LiteLLM (default-allow on unknown). The streaming
|
||||
# safety net still blocks if the model is *explicitly*
|
||||
# marked text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -253,15 +271,46 @@ class AgentConfig:
|
|||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
# Get system instructions from YAML, default to empty string
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||
# OpenRouter modalities. The YAML loader already populates this
|
||||
# field, but this method is also called from
|
||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||
# so we re-derive here for safety. The bool() coercion preserves
|
||||
# the loader's behaviour for explicit ``true`` / ``false``
|
||||
# strings that PyYAML may surface.
|
||||
if "supports_image_input" in yaml_config:
|
||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=yaml_config.get("provider", "").upper(),
|
||||
model_name=yaml_config.get("model_name", ""),
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=yaml_config.get("api_key", ""),
|
||||
api_base=yaml_config.get("api_base"),
|
||||
custom_provider=yaml_config.get("custom_provider"),
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=yaml_config.get("litellm_params"),
|
||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||
system_instructions=system_instructions if system_instructions else None,
|
||||
|
|
@ -276,6 +325,7 @@ class AgentConfig:
|
|||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ from app.agents.new_chat.middleware.file_intent import (
|
|||
from app.agents.new_chat.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
|
|
@ -61,6 +64,7 @@ __all__ = [
|
|||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,233 @@
|
|||
r"""Coalesce multi-block system messages into a single text block.
|
||||
|
||||
Several middlewares in our deepagent stack each call
|
||||
``append_to_system_message`` on the way down to the model
|
||||
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||
|
||||
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||
request**, and we configure 2 injection points
|
||||
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||
the prepended ``request.system_message``, this middleware is the
|
||||
defensive partner: it guarantees that "the system block" is *one*
|
||||
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||
several breakpoints by spreading ``cache_control`` across multiple
|
||||
text blocks of a multi-block system content.
|
||||
|
||||
Without flattening we used to see::
|
||||
|
||||
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||
cache_control may be provided. Found 5."}}}
|
||||
|
||||
(Same error class documented in
|
||||
https://github.com/BerriAI/litellm/issues/15696 and
|
||||
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||
in PR #15395 covers the litellm transformer but does not protect us
|
||||
when the OpenRouter SaaS itself does the redistribution.)
|
||||
|
||||
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
|
||||
the first injection point from ``role: system`` to ``index: 0``)
|
||||
neutralises the *primary* cause of the same 400 — multiple
|
||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||
matcher. This middleware remains useful as defence-in-depth against
|
||||
the multi-block redistribution path.
|
||||
|
||||
Placement: innermost on the system-message-mutation chain, after every
|
||||
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||
attempt sees a flattened payload. See ``chat_deepagent.py``.
|
||||
|
||||
Idempotent: a string-content system message is left untouched. A list
|
||||
that contains anything other than plain text blocks (e.g. an image) is
|
||||
also left untouched — those are rare on system messages and we'd lose
|
||||
the non-text payload by joining.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_text_blocks(content: list[Any]) -> str | None:
    """Join a list of text-only content blocks into a single string.

    Two entry shapes are accepted: a bare ``str``, and a dict of the form
    ``{"type": "text", "text": <str>}``. The function returns ``None`` as
    soon as it meets anything else — an entry that is neither ``str`` nor
    ``dict``, a dict whose ``type`` is not ``"text"`` (image, audio,
    file, ...), or a text block whose ``text`` value is not a string. The
    caller then keeps the original content rather than silently dropping
    payload.

    Only the ``text`` value of a dict block survives; any other key on the
    block is dropped. Dropping ``cache_control`` is intentional — the
    point of flattening is to let LiteLLM's
    ``cache_control_injection_points`` re-place a single breakpoint on the
    resulting one-block system content.

    Args:
        content: The list-form ``content`` of a message.

    Returns:
        The accepted pieces joined with a blank line between them, or
        ``None`` when the list cannot be safely flattened. An empty list
        yields an empty string.
    """
    pieces: list[str] = []
    for entry in content:
        if isinstance(entry, str):
            pieces.append(entry)
        elif isinstance(entry, dict) and entry.get("type") == "text":
            body = entry.get("text")
            if not isinstance(body, str):
                return None
            pieces.append(body)
        else:
            # Non-text block or unknown object: refuse to flatten.
            return None
    return "\n\n".join(pieces)
|
||||
|
||||
|
||||
def _flattened_request(
    request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
    """Build a copy of ``request`` whose system message is one string.

    Returns ``None`` — meaning "forward the request unchanged" — when
    there is no system message, when its content is not a list, when the
    list holds at most one block, or when ``_flatten_text_blocks``
    refuses the list because it contains something other than plain text.

    Otherwise a new ``SystemMessage`` is created from the joined text,
    carrying shallow copies of the original ``additional_kwargs`` and
    ``response_metadata`` and the original ``id`` (when set), and is
    swapped in through ``request.override``.
    """
    original = request.system_message
    if original is None:
        return None

    blocks = original.content
    # Only a list with two or more blocks needs collapsing.
    needs_flattening = isinstance(blocks, list) and len(blocks) > 1
    if not needs_flattening:
        return None

    joined = _flatten_text_blocks(blocks)
    if joined is None:
        return None

    replacement = SystemMessage(
        content=joined,
        additional_kwargs=dict(original.additional_kwargs),
        response_metadata=dict(original.response_metadata),
    )
    if original.id is not None:
        replacement.id = original.id
    return request.override(system_message=replacement)
|
||||
|
||||
|
||||
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
    """Return a one-line dump of the cache_control-relevant request shape.

    Temporary diagnostic to prove where the ``Found N`` cache_control
    breakpoints are coming from when Anthropic 400s. To be removed once
    the root cause is confirmed and a fix is in place.

    The summary reports:

    * ``sys`` — shape of ``request.system_message`` content (``none``,
      ``str(len=…)``, ``list(blocks=…)`` or ``other(<type>)``);
    * ``msgs`` — number of entries in ``request.messages``;
    * ``sys_msgs_in_history`` — how many of those are ``SystemMessage``;
    * ``multi_block_msgs`` — how many have list-form content;
    * ``pre_existing_msg_cc`` — how many already carry ``cache_control``,
      either on a content block or in ``additional_kwargs`` (each message
      is counted at most once);
    * ``tools`` / ``pre_existing_tool_cc`` — tool count and how many dict
      tools carry ``cache_control`` at top level or under ``function``;
    * ``roles`` — the ``type`` of the last eight messages.

    Args:
        request: The model request about to be sent.

    Returns:
        A single-line, human-readable summary string.
    """
    sys_msg = request.system_message
    if sys_msg is None:
        sys_shape = "none"
    elif isinstance(sys_msg.content, str):
        sys_shape = f"str(len={len(sys_msg.content)})"
    elif isinstance(sys_msg.content, list):
        sys_shape = f"list(blocks={len(sys_msg.content)})"
    else:
        sys_shape = f"other({type(sys_msg.content).__name__})"

    role_hist: list[str] = []
    multi_block_msgs = 0
    msgs_with_cc = 0
    sys_msgs_in_history = 0
    for m in request.messages:
        role_hist.append(getattr(m, "type", type(m).__name__))
        if isinstance(m, SystemMessage):
            sys_msgs_in_history += 1

        c = getattr(m, "content", None)
        block_has_cc = False
        if isinstance(c, list):
            multi_block_msgs += 1
            block_has_cc = any(
                isinstance(blk, dict) and "cache_control" in blk for blk in c
            )

        # Parenthesised on purpose: ``x in a or {}`` parses as
        # ``(x in a) or {}``, which leaves a ``None`` additional_kwargs
        # unguarded and raises TypeError.
        extra_kwargs = getattr(m, "additional_kwargs", None) or {}
        # Count the message once even if it is tagged in both places.
        if block_has_cc or "cache_control" in extra_kwargs:
            msgs_with_cc += 1

    tools = request.tools or []
    tools_with_cc = 0
    for t in tools:
        if not isinstance(t, dict):
            continue
        # ``function`` may be present but None; treat that as empty.
        if "cache_control" in t or "cache_control" in (t.get("function") or {}):
            tools_with_cc += 1

    return (
        f"sys={sys_shape} msgs={len(request.messages)} "
        f"sys_msgs_in_history={sys_msgs_in_history} "
        f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
        f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
        f"roles={role_hist[-8:]}"
    )
|
||||
|
||||
|
||||
class FlattenSystemMessageMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Collapse a multi-text-block system message into a single string.

    Intended to sit innermost on the system-message-mutation chain so it
    sees every other middleware's contribution. The text of each block is
    preserved and joined with a blank line; nothing else on the request
    is modified. Requests whose system message is absent, already a
    string, a single block, or not purely text are forwarded untouched.
    """

    def __init__(self) -> None:
        super().__init__()
        # This middleware contributes no tools of its own.
        self.tools = []

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> Any:
        """Synchronous hook: flatten the system message, then call ``handler``."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))

        collapsed = _flattened_request(request)
        if collapsed is None:
            # Nothing to flatten: pass the original request through.
            return handler(request)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[flatten_system] collapsed %d system blocks to one",
                len(request.system_message.content),  # type: ignore[arg-type, union-attr]
            )
        return handler(collapsed)

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> Any:
        """Asynchronous hook: flatten the system message, then await ``handler``."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))

        collapsed = _flattened_request(request)
        if collapsed is None:
            # Nothing to flatten: pass the original request through.
            return await handler(request)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[flatten_system] collapsed %d system blocks to one",
                len(request.system_message.content),  # type: ignore[arg-type, union-attr]
            )
        return await handler(collapsed)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"_flatten_text_blocks",
|
||||
"_flattened_request",
|
||||
]
|
||||
|
|
@ -732,7 +732,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
|
|
@ -755,7 +754,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
if anon_doc:
|
||||
return self._anon_priority(state, anon_doc)
|
||||
|
||||
return await self._authenticated_priority(state, messages, user_text)
|
||||
return await self._authenticated_priority(state, messages, user_text, runtime)
|
||||
|
||||
def _anon_priority(
|
||||
self,
|
||||
|
|
@ -787,6 +786,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
state: AgentState,
|
||||
messages: Sequence[BaseMessage],
|
||||
user_text: str,
|
||||
runtime: Runtime[Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
t0 = asyncio.get_event_loop().time()
|
||||
(
|
||||
|
|
@ -799,13 +799,45 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
user_text=user_text,
|
||||
)
|
||||
|
||||
# Per-turn ``mentioned_document_ids`` flow:
|
||||
# 1. Preferred path (Phase 1.5+): read from ``runtime.context`` — the
|
||||
# streaming task supplies a fresh :class:`SurfSenseContextSchema`
|
||||
# on every ``astream_events`` call, so this list is naturally
|
||||
# scoped to the current turn. Allows cross-turn graph reuse via
|
||||
# ``agent_cache``.
|
||||
# 2. Legacy fallback (cache disabled / context not propagated): the
|
||||
# constructor-injected ``self.mentioned_document_ids`` list. We
|
||||
# drain it after the first read so a cached graph (no Phase 1.5
|
||||
# wiring) doesn't keep replaying the same mentions on every
|
||||
# turn.
|
||||
#
|
||||
# CRITICAL: distinguish "context absent" (legacy caller, no field at
|
||||
# all) from "context provided but empty" (turn with no mentions).
|
||||
# ``ctx_mentions`` is a ``list[int]``; an empty list is falsy in
|
||||
# Python, so a naive ``if ctx_mentions:`` would fall through to the
|
||||
# legacy closure on every no-mention follow-up turn — replaying the
|
||||
# mentions baked in by turn 1's cache-miss build. Always drain the
|
||||
# closure once the runtime path has fired so a cached middleware
|
||||
# instance can never resurrect stale state.
|
||||
mention_ids: list[int] = []
|
||||
ctx = getattr(runtime, "context", None) if runtime is not None else None
|
||||
ctx_mentions = getattr(ctx, "mentioned_document_ids", None) if ctx else None
|
||||
if ctx_mentions is not None:
|
||||
# Runtime path is authoritative — even an empty list means
|
||||
# "this turn has no mentions", NOT "look at the closure".
|
||||
mention_ids = list(ctx_mentions)
|
||||
if self.mentioned_document_ids:
|
||||
self.mentioned_document_ids = []
|
||||
elif self.mentioned_document_ids:
|
||||
mention_ids = list(self.mentioned_document_ids)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if self.mentioned_document_ids:
|
||||
if mention_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
document_ids=self.mentioned_document_ids,
|
||||
document_ids=mention_ids,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||
|
||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||
|
|
@ -17,8 +17,20 @@ Coverage:
|
|||
|
||||
We inject **two** breakpoints per request:
|
||||
|
||||
- ``role: system`` — pins the SurfSense system prompt (provider variant,
|
||||
citation rules, tool catalog, KB tree, skills metadata) into the cache.
|
||||
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||
request (provider variant, citation rules, tool catalog, KB tree,
|
||||
skills metadata). The langchain agent factory always prepends
|
||||
``request.system_message`` at index 0 (see ``factory.py``
|
||||
``_execute_model_async``), so this targets exactly the main system
|
||||
prompt regardless of how many other ``SystemMessage``\ s the
|
||||
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||
``role: system`` here would apply ``cache_control`` to **every**
|
||||
system-role message and trip Anthropic's hard cap of 4 cache
|
||||
breakpoints per request once the conversation accumulates enough
|
||||
injected system messages — which surfaces as the upstream 400
|
||||
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||
via OpenRouter→Anthropic.
|
||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||
N+1 still reads turn N's cache up to the shared prefix.
|
||||
|
|
@ -51,11 +63,21 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Two-breakpoint policy: system + latest message. See module docstring for
|
||||
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
|
||||
# use 2 here, leaving headroom for Phase-2 tool caching.
|
||||
# Two-breakpoint policy: head-of-request + latest message. See module
|
||||
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
||||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||
#
|
||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
|
||||
# anonymous-doc) insert ``SystemMessage`` instances into
|
||||
# ``state["messages"]`` that accumulate across turns. With
|
||||
# ``role: system`` the LiteLLM hook would tag *every* one of them with
|
||||
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
|
||||
# always targets the langchain-prepended ``request.system_message``
|
||||
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
|
||||
# block), giving us exactly one stable cache breakpoint.
|
||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||
{"location": "message", "role": "system"},
|
||||
{"location": "message", "index": 0},
|
||||
{"location": "message", "index": -1},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -18,6 +19,23 @@ def create_create_confluence_page_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_confluence_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_confluence_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_confluence_page(
|
||||
title: str,
|
||||
|
|
@ -42,160 +60,163 @@ def create_create_confluence_page_tool(
|
|||
"""
|
||||
logger.info(f"create_confluence_page called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Confluence accounts need re-authentication.",
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Confluence accounts need re-authentication.",
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="confluence_page_creation",
|
||||
tool_name="create_confluence_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="confluence_page_creation",
|
||||
tool_name="create_confluence_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content) or ""
|
||||
final_space_id = result.params.get("space_id", space_id)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content) or ""
|
||||
final_space_id = result.params.get("space_id", space_id)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
return {"status": "error", "message": "Page title cannot be empty."}
|
||||
if not final_space_id:
|
||||
return {"status": "error", "message": "A space must be selected."}
|
||||
if not final_title or not final_title.strip():
|
||||
return {"status": "error", "message": "Page title cannot be empty."}
|
||||
if not final_space_id:
|
||||
return {"status": "error", "message": "A space must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Confluence connector found.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
api_result = await client.create_page(
|
||||
space_id=final_space_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_id = str(api_result.get("id", ""))
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=page_id,
|
||||
page_title=final_title,
|
||||
space_id=final_space_id,
|
||||
body_content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Confluence connector found.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
||||
}
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
api_result = await client.create_page(
|
||||
space_id=final_space_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_id = str(api_result.get("id", ""))
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=page_id,
|
||||
page_title=final_title,
|
||||
space_id=final_space_id,
|
||||
body_content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -18,6 +19,23 @@ def create_delete_confluence_page_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_confluence_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured delete_confluence_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_confluence_page(
|
||||
page_title_or_id: str,
|
||||
|
|
@ -43,137 +61,143 @@ def create_delete_confluence_page_tool(
|
|||
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
page_title = page_data.get("page_title", "")
|
||||
document_id = page_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="confluence_page_deletion",
|
||||
tool_name="delete_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
page_title = page_data.get("page_title", "")
|
||||
document_id = page_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="confluence_page_deletion",
|
||||
tool_name="delete_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
await client.delete_page(final_page_id)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
message = f"Confluence page '{page_title}' deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
await client.delete_page(final_page_id)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Confluence page '{page_title}' deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -18,6 +19,23 @@ def create_update_confluence_page_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the update_confluence_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured update_confluence_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_confluence_page(
|
||||
page_title_or_id: str,
|
||||
|
|
@ -45,164 +63,168 @@ def create_update_confluence_page_tool(
|
|||
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
current_title = page_data["page_title"]
|
||||
current_body = page_data.get("body", "")
|
||||
current_version = page_data.get("version", 1)
|
||||
document_id = page_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="confluence_page_update",
|
||||
tool_name="update_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
current_title = page_data["page_title"]
|
||||
current_body = page_data.get("body", "")
|
||||
current_version = page_data.get("version", 1)
|
||||
document_id = page_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="confluence_page_update",
|
||||
tool_name="update_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_title = result.params.get("new_title", new_title) or current_title
|
||||
final_content = result.params.get("new_content", new_content)
|
||||
if final_content is None:
|
||||
final_content = current_body
|
||||
final_version = result.params.get("version", current_version)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_title = result.params.get("new_title", new_title) or current_title
|
||||
final_content = result.params.get("new_content", new_content)
|
||||
if final_content is None:
|
||||
final_content = current_body
|
||||
final_version = result.params.get("version", current_version)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
api_result = await client.update_page(
|
||||
page_id=final_page_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
version_number=final_version + 1,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
page_id=final_page_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
api_result = await client.update_page(
|
||||
page_id=final_page_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
version_number=final_version + 1,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
page_id=final_page_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
else:
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -53,6 +53,23 @@ def create_get_connected_accounts_tool(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> StructuredTool:
|
||||
"""Factory function to create the get_connected_accounts tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to scope account discovery to.
|
||||
user_id: User ID to scope account discovery to.
|
||||
|
||||
Returns:
|
||||
Configured StructuredTool for connected-accounts discovery.
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
async def _run(service: str) -> list[dict[str, Any]]:
|
||||
svc_cfg = MCP_SERVICES.get(service)
|
||||
|
|
@ -68,40 +85,41 @@ def create_get_connected_accounts_tool(
|
|||
except ValueError:
|
||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return [
|
||||
{
|
||||
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
|
||||
if not connectors:
|
||||
return [
|
||||
{
|
||||
"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."
|
||||
}
|
||||
]
|
||||
|
||||
is_multi = len(connectors) > 1
|
||||
|
||||
accounts: list[dict[str, Any]] = []
|
||||
for conn in connectors:
|
||||
cfg = conn.config or {}
|
||||
entry: dict[str, Any] = {
|
||||
"connector_id": conn.id,
|
||||
"display_name": _extract_display_name(conn),
|
||||
"service": service,
|
||||
}
|
||||
]
|
||||
if is_multi:
|
||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||
for key in svc_cfg.account_metadata_keys:
|
||||
if key in cfg:
|
||||
entry[key] = cfg[key]
|
||||
accounts.append(entry)
|
||||
|
||||
is_multi = len(connectors) > 1
|
||||
|
||||
accounts: list[dict[str, Any]] = []
|
||||
for conn in connectors:
|
||||
cfg = conn.config or {}
|
||||
entry: dict[str, Any] = {
|
||||
"connector_id": conn.id,
|
||||
"display_name": _extract_display_name(conn),
|
||||
"service": service,
|
||||
}
|
||||
if is_multi:
|
||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||
for key in svc_cfg.account_metadata_keys:
|
||||
if key in cfg:
|
||||
entry[key] = cfg[key]
|
||||
accounts.append(entry)
|
||||
|
||||
return accounts
|
||||
return accounts
|
||||
|
||||
return StructuredTool(
|
||||
name="get_connected_accounts",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_list_discord_channels_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the list_discord_channels tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured list_discord_channels tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def list_discord_channels() -> dict[str, Any]:
|
||||
"""List text channels in the connected Discord server.
|
||||
|
|
@ -22,59 +41,60 @@ def create_list_discord_channels_tool(
|
|||
Returns:
|
||||
Dictionary with status and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Discord tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No guild ID in Discord connector config.",
|
||||
}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No guild ID in Discord connector config.",
|
||||
}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {
|
||||
"status": "success",
|
||||
"guild_id": guild_id,
|
||||
"channels": channels,
|
||||
"total": len(channels),
|
||||
}
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {
|
||||
"status": "success",
|
||||
"guild_id": guild_id,
|
||||
"channels": channels,
|
||||
"total": len(channels),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_read_discord_messages_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the read_discord_messages tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured read_discord_messages tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def read_discord_messages(
|
||||
channel_id: str,
|
||||
|
|
@ -30,7 +49,7 @@ def create_read_discord_messages_tool(
|
|||
Dictionary with status and a list of messages including
|
||||
id, author, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Discord tool not properly configured.",
|
||||
|
|
@ -39,55 +58,56 @@ def create_read_discord_messages_tool(
|
|||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Bot lacks permission to read this channel.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
token = get_bot_token(connector)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Bot lacks permission to read this channel.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
|
|
@ -17,6 +18,23 @@ def create_send_discord_message_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the send_discord_message tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured send_discord_message tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def send_discord_message(
|
||||
channel_id: str,
|
||||
|
|
@ -34,7 +52,7 @@ def create_send_discord_message_tool(
|
|||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Discord tool not properly configured.",
|
||||
|
|
@ -47,64 +65,65 @@ def create_send_discord_message_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_discord_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Message was not sent.",
|
||||
}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Bot lacks permission to send messages in this channel.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Message was not sent.",
|
||||
}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Discord bot token is invalid.",
|
||||
"connector_type": "discord",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Bot lacks permission to send messages in this channel.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Discord API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -59,6 +59,23 @@ def create_create_dropbox_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_dropbox_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_dropbox_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_dropbox_file(
|
||||
name: str,
|
||||
|
|
@ -82,184 +99,191 @@ def create_create_dropbox_file_tool(
|
|||
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Dropbox accounts need re-authentication.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = DropboxClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_folder("")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{
|
||||
"folder_path": item.get("path_lower", ""),
|
||||
"name": item["name"],
|
||||
}
|
||||
for item in items
|
||||
if item.get(".tag") == "folder" and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s", cid, exc_info=True
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
"supported_types": _SUPPORTED_TYPES,
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_creation",
|
||||
tool_name="create_dropbox_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_path": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_path = result.params.get("parent_folder_path")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_extension(final_name, final_file_type)
|
||||
|
||||
if final_connector_id is not None:
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
connector = connectors[0]
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid.",
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Dropbox accounts need re-authentication.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = DropboxClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_folder("")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{
|
||||
"folder_path": item.get("path_lower", ""),
|
||||
"name": item["name"],
|
||||
}
|
||||
for item in items
|
||||
if item.get(".tag") == "folder" and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s",
|
||||
cid,
|
||||
exc_info=True,
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
"supported_types": _SUPPORTED_TYPES,
|
||||
}
|
||||
|
||||
client = DropboxClient(session=db_session, connector_id=connector.id)
|
||||
|
||||
parent_path = final_parent_folder_path or ""
|
||||
file_path = (
|
||||
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
|
||||
)
|
||||
|
||||
if final_file_type == "paper":
|
||||
created = await client.create_paper_doc(file_path, final_content or "")
|
||||
file_id = created.get("file_id", "")
|
||||
web_url = created.get("url", "")
|
||||
else:
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
created = await client.upload_file(
|
||||
file_path, docx_bytes, mode="add", autorename=True
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_creation",
|
||||
tool_name="create_dropbox_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_path": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
file_id = created.get("id", "")
|
||||
web_url = ""
|
||||
|
||||
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.dropbox import DropboxKBSyncService
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_path = result.params.get("parent_folder_path")
|
||||
|
||||
kb_service = DropboxKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=file_id,
|
||||
file_name=final_name,
|
||||
file_path=file_path,
|
||||
web_url=web_url,
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_extension(final_name, final_file_type)
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
connector = connectors[0]
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"name": final_name,
|
||||
"web_url": web_url,
|
||||
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
|
||||
}
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid.",
|
||||
}
|
||||
|
||||
client = DropboxClient(session=db_session, connector_id=connector.id)
|
||||
|
||||
parent_path = final_parent_folder_path or ""
|
||||
file_path = (
|
||||
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
|
||||
)
|
||||
|
||||
if final_file_type == "paper":
|
||||
created = await client.create_paper_doc(
|
||||
file_path, final_content or ""
|
||||
)
|
||||
file_id = created.get("file_id", "")
|
||||
web_url = created.get("url", "")
|
||||
else:
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
created = await client.upload_file(
|
||||
file_path, docx_bytes, mode="add", autorename=True
|
||||
)
|
||||
file_id = created.get("id", "")
|
||||
web_url = ""
|
||||
|
||||
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.dropbox import DropboxKBSyncService
|
||||
|
||||
kb_service = DropboxKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=file_id,
|
||||
file_name=final_name,
|
||||
file_path=file_path,
|
||||
web_url=web_url,
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"name": final_name,
|
||||
"web_url": web_url,
|
||||
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.db import (
|
|||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
async_session_maker,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +24,23 @@ def create_delete_dropbox_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_dropbox_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured delete_dropbox_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_dropbox_file(
|
||||
file_name: str,
|
||||
|
|
@ -55,33 +73,14 @@ def create_delete_dropbox_file_tool(
|
|||
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
async with async_session_maker() as db_session:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
|
|
@ -92,13 +91,7 @@ def create_delete_dropbox_file_tool(
|
|||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata["dropbox_file_name"],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -107,99 +100,63 @@ def create_delete_dropbox_file_tool(
|
|||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed Dropbox files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
meta = document.document_metadata or {}
|
||||
file_path = meta.get("dropbox_path")
|
||||
file_id = meta.get("dropbox_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if not file_path:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File path is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
if not document:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata["dropbox_file_name"],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox connector not found or access denied.",
|
||||
}
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed Dropbox files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"file_path": file_path,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_trash",
|
||||
tool_name="delete_dropbox_file",
|
||||
params={
|
||||
"file_path": file_path,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
meta = document.document_metadata or {}
|
||||
file_path = meta.get("dropbox_path")
|
||||
file_id = meta.get("dropbox_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
if not file_path:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File path is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
final_file_path = result.params.get("file_path", file_path)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
|
|
@ -207,61 +164,128 @@ def create_delete_dropbox_file_tool(
|
|||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid or has been disconnected.",
|
||||
"message": "Dropbox connector not found or access denied.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
|
||||
)
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
|
||||
client = DropboxClient(session=db_session, connector_id=actual_connector_id)
|
||||
await client.delete_file(final_file_path)
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"file_path": file_path,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"Dropbox file deleted: path={final_file_path}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"message": f"Successfully deleted '{file_name}' from Dropbox.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_trash",
|
||||
tool_name="delete_dropbox_file",
|
||||
params={
|
||||
"file_path": file_path,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
return trash_result
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_file_path = result.params.get("file_path", file_path)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id
|
||||
== search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
client = DropboxClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
await client.delete_file(final_file_path)
|
||||
|
||||
logger.info(f"Dropbox file deleted: path={final_file_path}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"message": f"Successfully deleted '{file_name}' from Dropbox.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Build the ``"<prefix>/<model_name>"`` model string.

    The prefix comes from ``_resolve_provider_prefix``, which already returns
    ``custom_provider`` when it is set and the mapped provider name otherwise,
    so no separate custom-provider branch or map lookup is needed here.

    Args:
        provider: Provider name used when no custom provider is given.
        model_name: Model identifier appended after the slash.
        custom_provider: Optional prefix override.

    Returns:
        The prefix and model name joined by ``/``.
    """
    prefix = _resolve_provider_prefix(provider, custom_provider)
    return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
|
|
@ -146,14 +151,18 @@ def create_generate_image_tool(
|
|||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -175,14 +184,18 @@ def create_generate_image_tool(
|
|||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Any
|
||||
|
||||
from app.db import SearchSourceConnector
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
|
||||
def split_recipients(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
|
||||
|
||||
|
||||
def unwrap_composio_data(data: Any) -> Any:
    """Peel the ``data`` and ``response_data`` wrappers off a result payload.

    Args:
        data: Arbitrary payload; only dicts are unwrapped.

    Returns:
        ``data`` unchanged when it is not a dict. Otherwise the value under
        ``"data"`` (or the dict itself when that key is absent), and, if that
        value is itself a dict, its ``"response_data"`` entry when present.
    """
    if not isinstance(data, dict):
        return data
    payload = data.get("data", data)
    if not isinstance(payload, dict):
        return payload
    return payload.get("response_data", payload)
|
||||
|
||||
|
||||
async def execute_composio_gmail_tool(
    connector: SearchSourceConnector,
    user_id: str,
    tool_name: str,
    params: dict[str, Any],
) -> tuple[Any, str | None]:
    """Run a Composio tool for a Gmail connector and unwrap its result.

    Args:
        connector: Connector whose ``config`` is read for
            ``composio_connected_account_id``.
        user_id: Used to build the ``surfsense_<user_id>`` entity ID.
        tool_name: Composio tool name, passed through unchanged.
        params: Tool parameters, passed through unchanged.

    Returns:
        ``(data, None)`` on success, where ``data`` is the result's ``"data"``
        entry after ``unwrap_composio_data``; ``(None, message)`` when the
        account ID is missing or the result has no truthy ``"success"``.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if not account_id:
        return None, "Composio connected account ID not found for this Gmail connector."

    response = await ComposioService().execute_tool(
        connected_account_id=account_id,
        tool_name=tool_name,
        params=params,
        entity_id=f"surfsense_{user_id}",
    )
    if response.get("success"):
        return unwrap_composio_data(response.get("data")), None
    return None, response.get("error", "Unknown Composio Gmail error")
|
||||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,23 @@ def create_create_gmail_draft_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_gmail_draft tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_gmail_draft tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_gmail_draft(
|
||||
to: str,
|
||||
|
|
@ -57,246 +75,276 @@ def create_create_gmail_draft_tool(
|
|||
"""
|
||||
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_creation",
|
||||
tool_name="create_gmail_draft",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_creation",
|
||||
tool_name="create_gmail_draft",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
logger.info(
|
||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_gmail = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
)
|
||||
if is_composio_gmail:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
if is_composio_gmail:
|
||||
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||
execute_composio_gmail_tool,
|
||||
split_recipients,
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
|
||||
created, error = await execute_composio_gmail_tool(
|
||||
connector,
|
||||
user_id,
|
||||
"GMAIL_CREATE_EMAIL_DRAFT",
|
||||
{
|
||||
"user_id": "me",
|
||||
"recipient_email": final_to,
|
||||
"subject": final_subject,
|
||||
"body": final_body,
|
||||
"cc": split_recipients(final_cc),
|
||||
"bcc": split_recipients(final_bcc),
|
||||
"is_html": False,
|
||||
},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
if not isinstance(created, dict):
|
||||
created = {}
|
||||
else:
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
draft_message = created.get("message", {})
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=draft_message.get("id", ""),
|
||||
thread_id=draft_message.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
draft_id=created.get("id"),
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
draft_message = created.get("message", {})
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=draft_message.get("id", ""),
|
||||
thread_id=draft_message.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
draft_id=created.get("id"),
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": created.get("id"),
|
||||
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": created.get("id"),
|
||||
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -20,6 +20,23 @@ def create_read_gmail_email_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the read_gmail_email tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured read_gmail_email tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||
"""Read the full content of a specific Gmail email by its message ID.
|
||||
|
|
@ -32,60 +49,115 @@ def create_read_gmail_email_tool(
|
|||
Returns:
|
||||
Dictionary with status and the full email content formatted as markdown.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
detail, error = await gmail.get_message_details(message_id)
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
):
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "gmail",
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not detail:
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_format_gmail_summary,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
service = ComposioService()
|
||||
detail, error = await service.get_gmail_message_detail(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
message_id=message_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
if not detail:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Email with ID '{message_id}' not found.",
|
||||
}
|
||||
|
||||
summary = _format_gmail_summary(detail)
|
||||
content = (
|
||||
f"# {summary['subject']}\n\n"
|
||||
f"**From:** {summary['from']}\n"
|
||||
f"**To:** {summary['to']}\n"
|
||||
f"**Date:** {summary['date']}\n\n"
|
||||
f"## Message Content\n\n"
|
||||
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
|
||||
f"## Message Details\n\n"
|
||||
f"- **Message ID:** {summary['message_id']}\n"
|
||||
f"- **Thread ID:** {summary['thread_id']}\n"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": summary["message_id"] or message_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_build_credentials,
|
||||
)
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
detail, error = await gmail.get_message_details(message_id)
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not detail:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Email with ID '{message_id}' not found.",
|
||||
}
|
||||
|
||||
content = gmail.format_message_to_markdown(detail)
|
||||
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Email with ID '{message_id}' not found.",
|
||||
"status": "success",
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
content = gmail.format_message_to_markdown(detail)
|
||||
|
||||
return {"status": "success", "message_id": message_id, "content": content}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -39,12 +39,7 @@ def _build_credentials(connector: SearchSourceConnector):
|
|||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
raise ValueError("Composio connectors must use Composio tool execution.")
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -67,11 +62,85 @@ def _build_credentials(connector: SearchSourceConnector):
|
|||
)
|
||||
|
||||
|
||||
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
|
||||
headers = message.get("payload", {}).get("headers", [])
|
||||
return {
|
||||
header.get("name", "").lower(): header.get("value", "")
|
||||
for header in headers
|
||||
if isinstance(header, dict)
|
||||
}
|
||||
|
||||
|
||||
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
    """Reduce a Gmail message dict to a flat summary.

    Top-level keys (``subject``, ``sender``, ``to``, ``messageTimestamp``)
    are preferred; when one is missing or falsy the matching payload header
    from ``_gmail_headers`` is used instead.

    Args:
        message: Message dict to summarize.

    Returns:
        Dict with ``message_id``, ``thread_id``, ``subject``, ``from``,
        ``to``, ``date``, ``snippet`` and ``labels``. ``snippet`` falls back
        to the first 300 characters of ``messageText`` when no snippet is set.
    """
    headers = _gmail_headers(message)
    # ``dict.get`` returns a stored ``None`` instead of the default, so coerce
    # to "" before slicing; otherwise a null ``messageText`` raises TypeError.
    text_preview = (message.get("messageText") or "")[:300]
    return {
        "message_id": message.get("id") or message.get("messageId"),
        "thread_id": message.get("threadId"),
        "subject": message.get("subject") or headers.get("subject", "No Subject"),
        "from": message.get("sender") or headers.get("from", "Unknown"),
        "to": message.get("to") or headers.get("to", ""),
        "date": message.get("messageTimestamp") or headers.get("date", ""),
        "snippet": message.get("snippet") or text_preview,
        "labels": message.get("labelIds", []),
    }
|
||||
|
||||
|
||||
async def _search_composio_gmail(
    connector: SearchSourceConnector,
    user_id: str,
    query: str,
    max_results: int,
) -> dict[str, Any]:
    """Search Gmail through the Composio service for a Composio connector.

    Args:
        connector: Connector whose ``config`` holds the
            ``composio_connected_account_id``.
        user_id: User id; used to build the ``surfsense_<user_id>`` entity id.
        query: Search query forwarded to the service unchanged.
        max_results: Result cap forwarded to the service unchanged.

    Returns:
        ``{"status": "error", "message": ...}`` when the connected account
        id is missing or the service reports an error; otherwise a
        ``"success"`` dict with ``emails``, ``total`` and a ``message`` that
        is ``"No emails found."`` for an empty result and ``None`` otherwise.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if not account_id:
        return {
            "status": "error",
            "message": "Composio connected account ID not found.",
        }

    # Function-scope import, same as the original block.
    from app.services.composio_service import ComposioService

    fetched = await ComposioService().get_gmail_messages(
        connected_account_id=account_id,
        entity_id=f"surfsense_{user_id}",
        query=query,
        max_results=max_results,
    )
    # The service returns a 4-tuple; only the messages and the error are used.
    raw_messages, _next_token, _estimate, failure = fetched
    if failure:
        return {"status": "error", "message": failure}

    summaries = []
    for raw in raw_messages:
        summaries.append(_format_gmail_summary(raw))

    empty_note = None if summaries else "No emails found."
    return {
        "status": "success",
        "emails": summaries,
        "total": len(summaries),
        "message": empty_note,
    }
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the search_gmail tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured search_gmail tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def search_gmail(
|
||||
query: str,
|
||||
|
|
@ -90,83 +159,92 @@ def create_search_gmail_tool(
|
|||
Dictionary with status and a list of email summaries including
|
||||
message_id, subject, from, date, snippet.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 20)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
messages_list, error = await gmail.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
):
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "gmail",
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not messages_list:
|
||||
return {
|
||||
"status": "success",
|
||||
"emails": [],
|
||||
"total": 0,
|
||||
"message": "No emails found.",
|
||||
}
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
return await _search_composio_gmail(
|
||||
connector, str(user_id), query, max_results
|
||||
)
|
||||
|
||||
emails = []
|
||||
for msg in messages_list:
|
||||
detail, err = await gmail.get_message_details(msg["id"])
|
||||
if err:
|
||||
continue
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in detail.get("payload", {}).get("headers", [])
|
||||
}
|
||||
emails.append(
|
||||
{
|
||||
"message_id": detail.get("id"),
|
||||
"thread_id": detail.get("threadId"),
|
||||
"subject": headers.get("subject", "No Subject"),
|
||||
"from": headers.get("from", "Unknown"),
|
||||
"to": headers.get("to", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"snippet": detail.get("snippet", ""),
|
||||
"labels": detail.get("labelIds", []),
|
||||
}
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
messages_list, error = await gmail.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not messages_list:
|
||||
return {
|
||||
"status": "success",
|
||||
"emails": [],
|
||||
"total": 0,
|
||||
"message": "No emails found.",
|
||||
}
|
||||
|
||||
emails = []
|
||||
for msg in messages_list:
|
||||
detail, err = await gmail.get_message_details(msg["id"])
|
||||
if err:
|
||||
continue
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in detail.get("payload", {}).get("headers", [])
|
||||
}
|
||||
emails.append(
|
||||
{
|
||||
"message_id": detail.get("id"),
|
||||
"thread_id": detail.get("threadId"),
|
||||
"subject": headers.get("subject", "No Subject"),
|
||||
"from": headers.get("from", "Unknown"),
|
||||
"to": headers.get("to", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"snippet": detail.get("snippet", ""),
|
||||
"labels": detail.get("labelIds", []),
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,23 @@ def create_send_gmail_email_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the send_gmail_email tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured send_gmail_email tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def send_gmail_email(
|
||||
to: str,
|
||||
|
|
@ -58,247 +76,277 @@ def create_send_gmail_email_tool(
|
|||
"""
|
||||
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
logger.info(
|
||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_gmail = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
)
|
||||
if is_composio_gmail:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
if is_composio_gmail:
|
||||
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||
execute_composio_gmail_tool,
|
||||
split_recipients,
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
|
||||
sent, error = await execute_composio_gmail_tool(
|
||||
connector,
|
||||
user_id,
|
||||
"GMAIL_SEND_EMAIL",
|
||||
{
|
||||
"user_id": "me",
|
||||
"recipient_email": final_to,
|
||||
"subject": final_subject,
|
||||
"body": final_body,
|
||||
"cc": split_recipients(final_cc),
|
||||
"bcc": split_recipients(final_bcc),
|
||||
"is_html": False,
|
||||
},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
if not isinstance(sent, dict):
|
||||
sent = {}
|
||||
else:
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
)
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=sent.get("id", ""),
|
||||
thread_id=sent.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after send failed: {kb_err}")
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": sent.get("id"),
|
||||
"thread_id": sent.get("threadId"),
|
||||
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=sent.get("id", ""),
|
||||
thread_id=sent.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after send failed: {kb_err}")
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": sent.get("id"),
|
||||
"thread_id": sent.get("threadId"),
|
||||
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,6 +18,23 @@ def create_trash_gmail_email_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the trash_gmail_email tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured trash_gmail_email tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def trash_gmail_email(
|
||||
email_subject_or_id: str,
|
||||
|
|
@ -55,244 +73,261 @@ def create_trash_gmail_email_tool(
|
|||
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, email_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Email not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, email_subject_or_id
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Email not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not message_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
|
||||
}
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_email_trash",
|
||||
tool_name="trash_gmail_email",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_message_id = result.params.get("message_id", message_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this email.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not message_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
logger.info(
|
||||
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_email_trash",
|
||||
tool_name="trash_gmail_email",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message_id": final_message_id,
|
||||
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
final_message_id = result.params.get("message_id", message_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
return trash_result
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this email.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_gmail = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
)
|
||||
if is_composio_gmail:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
try:
|
||||
if is_composio_gmail:
|
||||
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||
execute_composio_gmail_tool,
|
||||
)
|
||||
|
||||
_trashed, error = await execute_composio_gmail_tool(
|
||||
connector,
|
||||
user_id,
|
||||
"GMAIL_MOVE_TO_TRASH",
|
||||
{"user_id": "me", "message_id": final_message_id},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
else:
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message_id": final_message_id,
|
||||
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,23 @@ def create_update_gmail_draft_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the update_gmail_draft tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured update_gmail_draft tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_gmail_draft(
|
||||
draft_subject_or_id: str,
|
||||
|
|
@ -76,294 +94,329 @@ def create_update_gmail_draft_tool(
|
|||
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, draft_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Draft not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = account["id"]
|
||||
draft_id_from_context = context.get("draft_id")
|
||||
|
||||
original_subject = email.get("subject", draft_subject_or_id)
|
||||
final_subject_default = subject if subject else original_subject
|
||||
final_to_default = to if to else ""
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_update",
|
||||
tool_name="update_gmail_draft",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_to = result.params.get("to", final_to_default)
|
||||
final_subject = result.params.get("subject", final_subject_default)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = result.params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this draft.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, draft_subject_or_id
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Draft not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = account["id"]
|
||||
draft_id_from_context = context.get("draft_id")
|
||||
|
||||
original_subject = email.get("subject", draft_subject_or_id)
|
||||
final_subject_default = subject if subject else original_subject
|
||||
final_to_default = to if to else ""
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_update",
|
||||
tool_name="update_gmail_draft",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_to = result.params.get("to", final_to_default)
|
||||
final_subject = result.params.get("subject", final_subject_default)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = result.params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
"message": "No connector found for this draft.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
from sqlalchemy.future import select
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
try:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
from app.db import Document
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
sa_select(Document).filter(Document.id == document_id)
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
document.source_markdown = final_body
|
||||
document.title = final_subject
|
||||
meta = dict(document.document_metadata or {})
|
||||
meta["subject"] = final_subject
|
||||
meta["draft_id"] = updated.get("id", final_draft_id)
|
||||
updated_msg = updated.get("message", {})
|
||||
if updated_msg.get("id"):
|
||||
meta["message_id"] = updated_msg["id"]
|
||||
document.document_metadata = meta
|
||||
flag_modified(document, "document_metadata")
|
||||
await db_session.commit()
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"KB document {document_id} updated for draft {final_draft_id}"
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_gmail = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
)
|
||||
if is_composio_gmail:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = (
|
||||
token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
if is_composio_gmail:
|
||||
final_draft_id = await _find_composio_draft_id_by_message(
|
||||
connector, user_id, message_id
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB update after draft edit failed: {kb_err}")
|
||||
await db_session.rollback()
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": updated.get("id"),
|
||||
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
if is_composio_gmail:
|
||||
from app.agents.new_chat.tools.gmail.composio_helpers import (
|
||||
execute_composio_gmail_tool,
|
||||
split_recipients,
|
||||
)
|
||||
|
||||
updated, error = await execute_composio_gmail_tool(
|
||||
connector,
|
||||
user_id,
|
||||
"GMAIL_UPDATE_DRAFT",
|
||||
{
|
||||
"user_id": "me",
|
||||
"draft_id": final_draft_id,
|
||||
"recipient_email": final_to,
|
||||
"subject": final_subject,
|
||||
"body": final_body,
|
||||
"cc": split_recipients(final_cc),
|
||||
"bcc": split_recipients(final_bcc),
|
||||
"is_html": False,
|
||||
},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
if not isinstance(updated, dict):
|
||||
updated = {}
|
||||
else:
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
try:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
sa_select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
document.source_markdown = final_body
|
||||
document.title = final_subject
|
||||
meta = dict(document.document_metadata or {})
|
||||
meta["subject"] = final_subject
|
||||
meta["draft_id"] = updated.get("id", final_draft_id)
|
||||
updated_msg = updated.get("message", {})
|
||||
if updated_msg.get("id"):
|
||||
meta["message_id"] = updated_msg["id"]
|
||||
document.document_metadata = meta
|
||||
flag_modified(document, "document_metadata")
|
||||
await db_session.commit()
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"KB document {document_id} updated for draft {final_draft_id}"
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB update after draft edit failed: {kb_err}")
|
||||
await db_session.rollback()
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": updated.get("id"),
|
||||
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
@ -408,3 +461,35 @@ async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to look up draft by message_id: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _find_composio_draft_id_by_message(
    connector: Any, user_id: str, message_id: str
) -> str | None:
    """Return the ID of the draft whose message matches ``message_id``.

    Walks the ``GMAIL_LIST_DRAFTS`` listing page by page (100 entries per
    request) through ``execute_composio_gmail_tool`` and stops at the first
    draft whose ``message.id`` equals ``message_id``.

    Args:
        connector: Passed through unchanged to ``execute_composio_gmail_tool``.
        user_id: Passed through unchanged to ``execute_composio_gmail_tool``.
        message_id: Message identifier compared against each draft's
            ``message.id``.

    Returns:
        The matching draft's ``id`` value, or ``None`` when the helper reports
        an error, returns something other than a dict, or the listing is
        exhausted without a match.
    """
    from app.agents.new_chat.tools.gmail.composio_helpers import (
        execute_composio_gmail_tool,
    )

    next_token = ""
    while True:
        request: dict[str, Any] = {
            "user_id": "me",
            "max_results": 100,
            "verbose": False,
        }
        # Only send a page token once a previous page has handed one back.
        if next_token:
            request["page_token"] = next_token

        payload, failure = await execute_composio_gmail_tool(
            connector, user_id, "GMAIL_LIST_DRAFTS", request
        )
        if failure or not isinstance(payload, dict):
            return None

        # First draft on this page that wraps the requested message, if any.
        hit = next(
            (
                entry
                for entry in payload.get("drafts", [])
                if entry.get("message", {}).get("id") == message_id
            ),
            None,
        )
        if hit is not None:
            return hit.get("id")

        # The continuation token is accepted under either spelling; a missing
        # or empty token means there are no further pages to scan.
        next_token = payload.get("nextPageToken") or payload.get("next_page_token")
        if not next_token:
            return None
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,23 @@ def create_create_calendar_event_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_calendar_event tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_calendar_event tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_calendar_event(
|
||||
summary: str,
|
||||
|
|
@ -60,254 +78,294 @@ def create_create_calendar_event_tool(
|
|||
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning(
|
||||
"All Google Calendar accounts have expired authentication"
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating calendar event: summary='{summary}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_creation",
|
||||
tool_name="create_calendar_event",
|
||||
params={
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_start_datetime = result.params.get("start_datetime", start_datetime)
|
||||
final_end_datetime = result.params.get("end_datetime", end_datetime)
|
||||
final_description = result.params.get("description", description)
|
||||
final_location = result.params.get("location", location)
|
||||
final_attendees = result.params.get("attendees", attendees)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Event summary cannot be empty."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning(
|
||||
"All Google Calendar accounts have expired authentication"
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating calendar event: summary='{summary}'"
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_creation",
|
||||
tool_name="create_calendar_event",
|
||||
params={
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_start_datetime = result.params.get(
|
||||
"start_datetime", start_datetime
|
||||
)
|
||||
final_end_datetime = result.params.get("end_datetime", end_datetime)
|
||||
final_description = result.params.get("description", description)
|
||||
final_location = result.params.get("location", location)
|
||||
final_attendees = result.params.get("attendees", attendees)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event summary cannot be empty.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
event_id=created.get("id"),
|
||||
event_summary=final_summary,
|
||||
calendar_id="primary",
|
||||
start_time=final_start_datetime,
|
||||
end_time=final_end_datetime,
|
||||
location=final_location,
|
||||
html_link=created.get("htmlLink"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": created.get("id"),
|
||||
"html_link": created.get("htmlLink"),
|
||||
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||
}
|
||||
logger.info(
|
||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_calendar = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
if is_composio_calendar:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
if is_composio_calendar:
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
composio_params = {
|
||||
"calendar_id": "primary",
|
||||
"summary": final_summary,
|
||||
"start_datetime": final_start_datetime,
|
||||
"end_datetime": final_end_datetime,
|
||||
"timezone": tz,
|
||||
"attendees": final_attendees or [],
|
||||
}
|
||||
if final_description:
|
||||
composio_params["description"] = final_description
|
||||
if final_location:
|
||||
composio_params["location"] = final_location
|
||||
|
||||
composio_result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLECALENDAR_CREATE_EVENT",
|
||||
params=composio_params,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not composio_result.get("success"):
|
||||
raise RuntimeError(
|
||||
composio_result.get(
|
||||
"error", "Unknown Composio Calendar error"
|
||||
)
|
||||
)
|
||||
created = composio_result.get("data", {})
|
||||
if isinstance(created, dict):
|
||||
created = created.get("data", created)
|
||||
if isinstance(created, dict):
|
||||
created = created.get("response_data", created)
|
||||
else:
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
event_id=created.get("id"),
|
||||
event_summary=final_summary,
|
||||
calendar_id="primary",
|
||||
start_time=final_start_datetime,
|
||||
end_time=final_end_datetime,
|
||||
location=final_location,
|
||||
html_link=created.get("htmlLink"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": created.get("id"),
|
||||
"html_link": created.get("htmlLink"),
|
||||
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,23 @@ def create_delete_calendar_event_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_calendar_event tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured delete_calendar_event tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_calendar_event(
|
||||
event_title_or_id: str,
|
||||
|
|
@ -54,240 +72,258 @@ def create_delete_calendar_event_tool(
|
|||
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch deletion context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Calendar account %s has expired authentication",
|
||||
account.get("id"),
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch deletion context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Calendar account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_deletion",
|
||||
tool_name="delete_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
logger.info(
|
||||
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_deletion",
|
||||
tool_name="delete_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
delete_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
delete_result["warning"] = (
|
||||
f"Event deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
delete_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
delete_result["message"] = (
|
||||
f"{delete_result.get('message', '')} (also removed from knowledge base)"
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
return delete_result
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_calendar = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
if is_composio_calendar:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
try:
|
||||
if is_composio_calendar:
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
composio_result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLECALENDAR_DELETE_EVENT",
|
||||
params={
|
||||
"calendar_id": "primary",
|
||||
"event_id": final_event_id,
|
||||
},
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not composio_result.get("success"):
|
||||
raise RuntimeError(
|
||||
composio_result.get(
|
||||
"error", "Unknown Composio Calendar error"
|
||||
)
|
||||
)
|
||||
else:
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
delete_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
delete_result["warning"] = (
|
||||
f"Event deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
delete_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
delete_result["message"] = (
|
||||
f"{delete_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return delete_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -16,11 +16,57 @@ _CALENDAR_TYPES = [
|
|||
]
|
||||
|
||||
|
||||
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
|
||||
if "T" in value:
|
||||
return value
|
||||
time = "23:59:59" if is_end else "00:00:00"
|
||||
return f"{value}T{time}Z"
|
||||
|
||||
|
||||
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
events = []
|
||||
for ev in events_raw:
|
||||
start = ev.get("start", {})
|
||||
end = ev.get("end", {})
|
||||
attendees_raw = ev.get("attendees", [])
|
||||
events.append(
|
||||
{
|
||||
"event_id": ev.get("id"),
|
||||
"summary": ev.get("summary", "No Title"),
|
||||
"start": start.get("dateTime") or start.get("date", ""),
|
||||
"end": end.get("dateTime") or end.get("date", ""),
|
||||
"location": ev.get("location", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"html_link": ev.get("htmlLink", ""),
|
||||
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
||||
"status": ev.get("status", ""),
|
||||
}
|
||||
)
|
||||
return events
|
||||
|
||||
|
||||
def create_search_calendar_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the search_calendar_events tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured search_calendar_events tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def search_calendar_events(
|
||||
start_date: str,
|
||||
|
|
@ -38,7 +84,7 @@ def create_search_calendar_events_tool(
|
|||
Dictionary with status and a list of events including
|
||||
event_id, summary, start, end, location, attendees.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Calendar tool not properly configured.",
|
||||
|
|
@ -47,76 +93,85 @@ def create_search_calendar_events_tool(
|
|||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
if "no events found" in error.lower():
|
||||
return {
|
||||
"status": "success",
|
||||
"events": [],
|
||||
"total": 0,
|
||||
"message": error,
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
events = []
|
||||
for ev in events_raw:
|
||||
start = ev.get("start", {})
|
||||
end = ev.get("end", {})
|
||||
attendees_raw = ev.get("attendees", [])
|
||||
events.append(
|
||||
{
|
||||
"event_id": ev.get("id"),
|
||||
"summary": ev.get("summary", "No Title"),
|
||||
"start": start.get("dateTime") or start.get("date", ""),
|
||||
"end": end.get("dateTime") or end.get("date", ""),
|
||||
"location": ev.get("location", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"html_link": ev.get("htmlLink", ""),
|
||||
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
|
||||
"status": ev.get("status", ""),
|
||||
}
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
events_raw, error = await ComposioService().get_calendar_events(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
time_min=_to_calendar_boundary(start_date, is_end=False),
|
||||
time_max=_to_calendar_boundary(end_date, is_end=True),
|
||||
max_results=max_results,
|
||||
)
|
||||
if not events_raw and not error:
|
||||
error = "No events found in the specified date range."
|
||||
else:
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_calendar_connector import (
|
||||
GoogleCalendarConnector,
|
||||
)
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if (
|
||||
"re-authenticate" in error.lower()
|
||||
or "authentication failed" in error.lower()
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error,
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
if "no events found" in error.lower():
|
||||
return {
|
||||
"status": "success",
|
||||
"events": [],
|
||||
"total": 0,
|
||||
"message": error,
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
events = _format_calendar_events(events_raw)
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -33,6 +34,23 @@ def create_update_calendar_event_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the update_calendar_event tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured update_calendar_event tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_calendar_event(
|
||||
event_title_or_id: str,
|
||||
|
|
@ -74,272 +92,317 @@ def create_update_calendar_event_tool(
|
|||
"""
|
||||
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if context.get("auth_expired"):
|
||||
logger.warning("Google Calendar account has expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_update",
|
||||
tool_name="update_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_new_summary = result.params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = result.params.get(
|
||||
"new_start_datetime", new_start_datetime
|
||||
)
|
||||
final_new_end_datetime = result.params.get(
|
||||
"new_end_datetime", new_end_datetime
|
||||
)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_location = result.params.get("new_location", new_location)
|
||||
final_new_attendees = result.params.get("new_attendees", new_attendees)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
logger.info(
|
||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
if context.get("auth_expired"):
|
||||
logger.warning("Google Calendar account has expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
logger.info(
|
||||
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_update",
|
||||
tool_name="update_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(final_new_end_datetime, context)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
final_new_summary = result.params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = result.params.get(
|
||||
"new_start_datetime", new_start_datetime
|
||||
)
|
||||
final_new_end_datetime = result.params.get(
|
||||
"new_end_datetime", new_end_datetime
|
||||
)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_location = result.params.get("new_location", new_location)
|
||||
final_new_attendees = result.params.get("new_attendees", new_attendees)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
actual_connector_id = connector.id
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
logger.info(
|
||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
event_id=final_event_id,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
is_composio_calendar = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
)
|
||||
if is_composio_calendar:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"html_link": updated.get("htmlLink"),
|
||||
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
|
||||
}
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
)
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(
|
||||
final_new_end_datetime, context
|
||||
)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
if is_composio_calendar:
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
composio_params: dict[str, Any] = {
|
||||
"calendar_id": "primary",
|
||||
"event_id": final_event_id,
|
||||
}
|
||||
if final_new_summary is not None:
|
||||
composio_params["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
composio_params["start_time"] = final_new_start_datetime
|
||||
if final_new_end_datetime is not None:
|
||||
composio_params["end_time"] = final_new_end_datetime
|
||||
if final_new_description is not None:
|
||||
composio_params["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
composio_params["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
composio_params["attendees"] = [
|
||||
e.strip() for e in final_new_attendees if e.strip()
|
||||
]
|
||||
if not _is_date_only(
|
||||
final_new_start_datetime or final_new_end_datetime or ""
|
||||
):
|
||||
composio_params["timezone"] = context.get("timezone", "UTC")
|
||||
|
||||
composio_result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLECALENDAR_PATCH_EVENT",
|
||||
params=composio_params,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not composio_result.get("success"):
|
||||
raise RuntimeError(
|
||||
composio_result.get(
|
||||
"error", "Unknown Composio Calendar error"
|
||||
)
|
||||
)
|
||||
updated = composio_result.get("data", {})
|
||||
if isinstance(updated, dict):
|
||||
updated = updated.get("data", updated)
|
||||
if isinstance(updated, dict):
|
||||
updated = updated.get("response_data", updated)
|
||||
else:
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
try:
|
||||
from app.services.google_calendar import (
|
||||
GoogleCalendarKBSyncService,
|
||||
)
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
event_id=final_event_id,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"html_link": updated.get("htmlLink"),
|
||||
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||
from app.db import async_session_maker
|
||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +24,25 @@ def create_create_google_drive_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_google_drive_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Google Drive connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
|
||||
Returns:
|
||||
Configured create_google_drive_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_google_drive_file(
|
||||
name: str,
|
||||
|
|
@ -65,7 +85,7 @@ def create_create_google_drive_file_tool(
|
|||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||
|
|
@ -78,195 +98,232 @@ def create_create_google_drive_file_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Google Drive accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_creation",
|
||||
tool_name="create_google_drive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
mime_type = _MIME_MAP.get(final_file_type)
|
||||
if not mime_type:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unsupported file type '{final_file_type}'.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
"All Google Drive accounts have expired authentication"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_drive import GoogleDriveKBSyncService
|
||||
|
||||
kb_service = GoogleDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=mime_type,
|
||||
web_view_link=created.get("webViewLink"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_creation",
|
||||
tool_name="create_google_drive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
mime_type = _MIME_MAP.get(final_file_type)
|
||||
if not mime_type:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unsupported file type '{final_file_type}'.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_drive = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
)
|
||||
if is_composio_drive:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Drive connector.",
|
||||
}
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
try:
|
||||
if is_composio_drive:
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"name": final_name,
|
||||
"mimeType": mime_type,
|
||||
"fields": "id,name,webViewLink,mimeType",
|
||||
}
|
||||
if final_parent_folder_id:
|
||||
params["parents"] = [final_parent_folder_id]
|
||||
if final_content:
|
||||
params["description"] = final_content[:4096]
|
||||
|
||||
result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLEDRIVE_CREATE_FILE",
|
||||
params=params,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(
|
||||
result.get("error", "Unknown Composio Drive error")
|
||||
)
|
||||
created = result.get("data", {})
|
||||
if isinstance(created, dict):
|
||||
created = created.get("data", created)
|
||||
if isinstance(created, dict):
|
||||
created = created.get("response_data", created)
|
||||
if not isinstance(created, dict):
|
||||
created = {}
|
||||
else:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_drive import GoogleDriveKBSyncService
|
||||
|
||||
kb_service = GoogleDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=mime_type,
|
||||
web_view_link=created.get("webViewLink"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.db import async_session_maker
|
||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,6 +18,25 @@ def create_delete_google_drive_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_google_drive_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Google Drive connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
|
||||
Returns:
|
||||
Configured delete_google_drive_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_google_drive_file(
|
||||
file_name: str,
|
||||
|
|
@ -55,197 +75,214 @@ def create_delete_google_drive_file_tool(
|
|||
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, file_name
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"File not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Drive account %s has expired authentication",
|
||||
account.get("id"),
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, file_name
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"File not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_trash",
|
||||
tool_name="delete_google_drive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this file.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
"Google Drive account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "auth_error",
|
||||
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
)
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file['name']}' to trash.",
|
||||
}
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_trash",
|
||||
tool_name="delete_google_drive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
return trash_result
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this file.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
is_composio_drive = (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
)
|
||||
if is_composio_drive:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Drive connector.",
|
||||
}
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
try:
|
||||
if is_composio_drive:
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLEDRIVE_TRASH_FILE",
|
||||
params={"file_id": final_file_id},
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(
|
||||
result.get("error", "Unknown Composio Drive error")
|
||||
)
|
||||
else:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file['name']}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
|||
{
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"create_calendar_event",
|
||||
"create_notion_page",
|
||||
"create_confluence_page",
|
||||
"create_google_drive_file",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,28 @@ def create_create_jira_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""Factory function to create the create_jira_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits. Per-call sessions also
|
||||
keep the request's outer transaction free of long-running Jira API
|
||||
blocking.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Jira connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
||||
Returns:
|
||||
Configured create_jira_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_jira_issue(
|
||||
project_key: str,
|
||||
|
|
@ -49,158 +72,167 @@ def create_create_jira_issue_tool(
|
|||
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Jira accounts need re-authentication.",
|
||||
"connector_type": "jira",
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="jira_issue_creation",
|
||||
tool_name="create_jira_issue",
|
||||
params={
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_project_key = result.params.get("project_key", project_key)
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_issue_type = result.params.get("issue_type", issue_type)
|
||||
final_description = result.params.get("description", description)
|
||||
final_priority = result.params.get("priority", priority)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Issue summary cannot be empty."}
|
||||
if not final_project_key:
|
||||
return {"status": "error", "message": "A project must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Jira connector found."}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Jira accounts need re-authentication.",
|
||||
"connector_type": "jira",
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="jira_issue_creation",
|
||||
tool_name="create_jira_issue",
|
||||
params={
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_project_key = result.params.get("project_key", project_key)
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_issue_type = result.params.get("issue_type", issue_type)
|
||||
final_description = result.params.get("description", description)
|
||||
final_priority = result.params.get("priority", priority)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
"message": "Issue summary cannot be empty.",
|
||||
}
|
||||
if not final_project_key:
|
||||
return {"status": "error", "message": "A project must be selected."}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
api_result = await asyncio.to_thread(
|
||||
jira_client.create_issue,
|
||||
project_key=final_project_key,
|
||||
summary=final_summary,
|
||||
issue_type=final_issue_type,
|
||||
description=final_description,
|
||||
priority=final_priority,
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
from sqlalchemy.future import select
|
||||
|
||||
issue_key = api_result.get("key", "")
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{issue_key}"
|
||||
if jira_history._base_url and issue_key
|
||||
else ""
|
||||
)
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=issue_key,
|
||||
issue_identifier=issue_key,
|
||||
issue_title=final_summary,
|
||||
description=final_description,
|
||||
state="To Do",
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Jira connector found.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
|
||||
}
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
api_result = await asyncio.to_thread(
|
||||
jira_client.create_issue,
|
||||
project_key=final_project_key,
|
||||
summary=final_summary,
|
||||
issue_type=final_issue_type,
|
||||
description=final_description,
|
||||
priority=final_priority,
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_key = api_result.get("key", "")
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{issue_key}"
|
||||
if jira_history._base_url and issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=issue_key,
|
||||
issue_identifier=issue_key,
|
||||
issue_title=final_summary,
|
||||
description=final_description,
|
||||
state="To Do",
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,26 @@ def create_delete_jira_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""Factory function to create the delete_jira_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Jira connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
||||
Returns:
|
||||
Configured delete_jira_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
|
|
@ -44,130 +65,136 @@ def create_delete_jira_issue_tool(
|
|||
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="jira_issue_deletion",
|
||||
tool_name="delete_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="jira_issue_deletion",
|
||||
tool_name="delete_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
message = f"Jira issue {final_issue_key} deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Jira issue {final_issue_key} deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.orm.attributes import flag_modified
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,6 +20,26 @@ def create_update_jira_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""Factory function to create the update_jira_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Jira connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
||||
Returns:
|
||||
Configured update_jira_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
|
|
@ -48,169 +69,177 @@ def create_update_jira_issue_tool(
|
|||
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
result = request_approval(
|
||||
action_type="jira_issue_update",
|
||||
tool_name="update_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_summary = result.params.get("new_summary", new_summary)
|
||||
final_description = result.params.get("new_description", new_description)
|
||||
final_priority = result.params.get("new_priority", new_priority)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
fields: dict[str, Any] = {}
|
||||
if final_summary:
|
||||
fields["summary"] = final_summary
|
||||
if final_description is not None:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [{"type": "text", "text": final_description}],
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
],
|
||||
}
|
||||
if final_priority:
|
||||
fields["priority"] = {"name": final_priority}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not fields:
|
||||
return {"status": "error", "message": "No changes specified."}
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
result = request_approval(
|
||||
action_type="jira_issue_update",
|
||||
tool_name="update_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(
|
||||
jira_client.update_issue, final_issue_key, fields
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{final_issue_key}"
|
||||
if jira_history._base_url and final_issue_key
|
||||
else ""
|
||||
)
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_summary = result.params.get("new_summary", new_summary)
|
||||
final_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_priority = result.params.get("new_priority", new_priority)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
from sqlalchemy.future import select
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_key,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
fields: dict[str, Any] = {}
|
||||
if final_summary:
|
||||
fields["summary"] = final_summary
|
||||
if final_description is not None:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [
|
||||
{"type": "text", "text": final_description}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
if final_priority:
|
||||
fields["priority"] = {"name": final_priority}
|
||||
|
||||
if not fields:
|
||||
return {"status": "error", "message": "No changes specified."}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(
|
||||
jira_client.update_issue, final_issue_key, fields
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{final_issue_key}"
|
||||
if jira_history._base_url and final_issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_key,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
else:
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,11 +18,17 @@ def create_create_linear_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_linear_issue tool.
|
||||
"""Factory function to create the create_linear_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing the Linear connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Linear connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +36,7 @@ def create_create_linear_issue_tool(
|
|||
Returns:
|
||||
Configured create_linear_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_linear_issue(
|
||||
|
|
@ -65,7 +73,7 @@ def create_create_linear_issue_tool(
|
|||
"""
|
||||
logger.info(f"create_linear_issue called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Linear tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -75,160 +83,170 @@ def create_create_linear_issue_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
workspaces = context.get("workspaces", [])
|
||||
if workspaces and all(w.get("auth_expired") for w in workspaces):
|
||||
logger.warning("All Linear accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "linear",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
result = request_approval(
|
||||
action_type="linear_issue_creation",
|
||||
tool_name="create_linear_issue",
|
||||
params={
|
||||
"title": title,
|
||||
"description": description,
|
||||
"team_id": None,
|
||||
"state_id": None,
|
||||
"assignee_id": None,
|
||||
"priority": None,
|
||||
"label_ids": [],
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_title = result.params.get("title", title)
|
||||
final_description = result.params.get("description", description)
|
||||
final_team_id = result.params.get("team_id")
|
||||
final_state_id = result.params.get("state_id")
|
||||
final_assignee_id = result.params.get("assignee_id")
|
||||
final_priority = result.params.get("priority")
|
||||
final_label_ids = result.params.get("label_ids") or []
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
return {"status": "error", "message": "Issue title cannot be empty."}
|
||||
if not final_team_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "A team must be selected to create an issue.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
workspaces = context.get("workspaces", [])
|
||||
if workspaces and all(w.get("auth_expired") for w in workspaces):
|
||||
logger.warning("All Linear accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "linear",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
result = request_approval(
|
||||
action_type="linear_issue_creation",
|
||||
tool_name="create_linear_issue",
|
||||
params={
|
||||
"title": title,
|
||||
"description": description,
|
||||
"team_id": None,
|
||||
"state_id": None,
|
||||
"assignee_id": None,
|
||||
"priority": None,
|
||||
"label_ids": [],
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_title = result.params.get("title", title)
|
||||
final_description = result.params.get("description", description)
|
||||
final_team_id = result.params.get("team_id")
|
||||
final_state_id = result.params.get("state_id")
|
||||
final_assignee_id = result.params.get("assignee_id")
|
||||
final_priority = result.params.get("priority")
|
||||
final_label_ids = result.params.get("label_ids") or []
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
|
||||
"message": "Issue title cannot be empty.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Found Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
if not final_team_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
"message": "A team must be selected to create an issue.",
|
||||
}
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
|
||||
logger.info(
|
||||
f"Creating Linear issue with final params: title='{final_title}'"
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
result = await linear_client.create_issue(
|
||||
team_id=final_team_id,
|
||||
title=final_title,
|
||||
description=final_description,
|
||||
state_id=final_state_id,
|
||||
assignee_id=final_assignee_id,
|
||||
priority=final_priority,
|
||||
label_ids=final_label_ids if final_label_ids else None,
|
||||
)
|
||||
from sqlalchemy.future import select
|
||||
|
||||
if result.get("status") == "error":
|
||||
logger.error(f"Failed to create Linear issue: {result.get('message')}")
|
||||
return {"status": "error", "message": result.get("message")}
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger.info(
|
||||
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.linear import LinearKBSyncService
|
||||
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=result.get("id"),
|
||||
issue_identifier=result.get("identifier", ""),
|
||||
issue_title=result.get("title", final_title),
|
||||
issue_url=result.get("url"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Found Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_id": result.get("id"),
|
||||
"identifier": result.get("identifier"),
|
||||
"url": result.get("url"),
|
||||
"message": (result.get("message", "") + kb_message_suffix),
|
||||
}
|
||||
logger.info(
|
||||
f"Creating Linear issue with final params: title='{final_title}'"
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
result = await linear_client.create_issue(
|
||||
team_id=final_team_id,
|
||||
title=final_title,
|
||||
description=final_description,
|
||||
state_id=final_state_id,
|
||||
assignee_id=final_assignee_id,
|
||||
priority=final_priority,
|
||||
label_ids=final_label_ids if final_label_ids else None,
|
||||
)
|
||||
|
||||
if result.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to create Linear issue: {result.get('message')}"
|
||||
)
|
||||
return {"status": "error", "message": result.get("message")}
|
||||
|
||||
logger.info(
|
||||
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.linear import LinearKBSyncService
|
||||
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=result.get("id"),
|
||||
issue_identifier=result.get("identifier", ""),
|
||||
issue_title=result.get("title", final_title),
|
||||
issue_url=result.get("url"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_id": result.get("id"),
|
||||
"identifier": result.get("identifier"),
|
||||
"url": result.get("url"),
|
||||
"message": (result.get("message", "") + kb_message_suffix),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,11 +18,17 @@ def create_delete_linear_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_linear_issue tool.
|
||||
"""Factory function to create the delete_linear_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing the Linear connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Linear connector
|
||||
user_id: User ID for finding the correct Linear connector
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +36,7 @@ def create_delete_linear_issue_tool(
|
|||
Returns:
|
||||
Configured delete_linear_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_linear_issue(
|
||||
|
|
@ -73,7 +81,7 @@ def create_delete_linear_issue_tool(
|
|||
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Linear tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -83,149 +91,152 @@ def create_delete_linear_issue_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_delete_context(
|
||||
search_space_id, user_id, issue_ref
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for delete context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
else:
|
||||
logger.error(f"Failed to fetch delete context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_id = context["issue"]["id"]
|
||||
issue_identifier = context["issue"].get("identifier", "")
|
||||
document_id = context["issue"]["document_id"]
|
||||
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
|
||||
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="linear_issue_deletion",
|
||||
tool_name="delete_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
||||
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_delete_context(
|
||||
search_space_id, user_id, issue_ref
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for delete context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
else:
|
||||
logger.error(f"Failed to fetch delete context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_id = context["issue"]["id"]
|
||||
issue_identifier = context["issue"].get("identifier", "")
|
||||
document_id = context["issue"]["document_id"]
|
||||
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
|
||||
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="linear_issue_deletion",
|
||||
tool_name="delete_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
||||
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this issue")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this issue")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
|
||||
result = await linear_client.archive_issue(issue_id=final_issue_id)
|
||||
result = await linear_client.archive_issue(issue_id=final_issue_id)
|
||||
|
||||
logger.info(
|
||||
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
logger.info(
|
||||
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
deleted_from_kb = False
|
||||
if (
|
||||
result.get("status") == "success"
|
||||
and final_delete_from_kb
|
||||
and document_id
|
||||
):
|
||||
try:
|
||||
from app.db import Document
|
||||
deleted_from_kb = False
|
||||
if (
|
||||
result.get("status") == "success"
|
||||
and final_delete_from_kb
|
||||
and document_id
|
||||
):
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
result["warning"] = (
|
||||
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
result["warning"] = (
|
||||
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
result["deleted_from_kb"] = deleted_from_kb
|
||||
if issue_identifier:
|
||||
result["message"] = (
|
||||
f"Issue {issue_identifier} archived successfully."
|
||||
)
|
||||
if deleted_from_kb:
|
||||
result["message"] = (
|
||||
f"{result.get('message', '')} Also removed from the knowledge base."
|
||||
)
|
||||
if result.get("status") == "success":
|
||||
result["deleted_from_kb"] = deleted_from_kb
|
||||
if issue_identifier:
|
||||
result["message"] = (
|
||||
f"Issue {issue_identifier} archived successfully."
|
||||
)
|
||||
if deleted_from_kb:
|
||||
result["message"] = (
|
||||
f"{result.get('message', '')} Also removed from the knowledge base."
|
||||
)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,11 +18,17 @@ def create_update_linear_issue_tool(
|
|||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the update_linear_issue tool.
|
||||
"""Factory function to create the update_linear_issue tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing the Linear connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Linear connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +36,7 @@ def create_update_linear_issue_tool(
|
|||
Returns:
|
||||
Configured update_linear_issue tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_linear_issue(
|
||||
|
|
@ -86,7 +94,7 @@ def create_update_linear_issue_tool(
|
|||
"""
|
||||
logger.info(f"update_linear_issue called: issue_ref='{issue_ref}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Linear tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -96,176 +104,177 @@ def create_update_linear_issue_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_ref
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for update context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
else:
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_id = context["issue"]["id"]
|
||||
document_id = context["issue"]["document_id"]
|
||||
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||
|
||||
team = context.get("team", {})
|
||||
new_state_id = _resolve_state(team, new_state_name)
|
||||
new_assignee_id = _resolve_assignee(team, new_assignee_email)
|
||||
new_label_ids = _resolve_labels(team, new_label_names)
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="linear_issue_update",
|
||||
tool_name="update_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_description": new_description,
|
||||
"new_state_id": new_state_id,
|
||||
"new_assignee_id": new_assignee_id,
|
||||
"new_priority": new_priority,
|
||||
"new_label_ids": new_label_ids,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
final_new_title = result.params.get("new_title", new_title)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_state_id = result.params.get("new_state_id", new_state_id)
|
||||
final_new_assignee_id = result.params.get(
|
||||
"new_assignee_id", new_assignee_id
|
||||
)
|
||||
final_new_priority = result.params.get("new_priority", new_priority)
|
||||
final_new_label_ids: list[str] | None = result.params.get(
|
||||
"new_label_ids", new_label_ids
|
||||
)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
if not final_connector_id:
|
||||
logger.error("No connector found for this issue")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_ref
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
logger.info(f"Validated Linear connector: id={final_connector_id}")
|
||||
|
||||
logger.info(
|
||||
f"Updating Linear issue with final params: issue_id={final_issue_id}"
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
updated_issue = await linear_client.update_issue(
|
||||
issue_id=final_issue_id,
|
||||
title=final_new_title,
|
||||
description=final_new_description,
|
||||
state_id=final_new_state_id,
|
||||
assignee_id=final_new_assignee_id,
|
||||
priority=final_new_priority,
|
||||
label_ids=final_new_label_ids,
|
||||
)
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for update context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
else:
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if updated_issue.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to update Linear issue: {updated_issue.get('message')}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": updated_issue.get("message"),
|
||||
}
|
||||
issue_id = context["issue"]["id"]
|
||||
document_id = context["issue"]["document_id"]
|
||||
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
|
||||
)
|
||||
team = context.get("team", {})
|
||||
new_state_id = _resolve_state(team, new_state_name)
|
||||
new_assignee_id = _resolve_assignee(team, new_assignee_email)
|
||||
new_label_ids = _resolve_labels(team, new_label_names)
|
||||
|
||||
if final_document_id is not None:
|
||||
logger.info(
|
||||
f"Updating knowledge base for document {final_document_id}..."
|
||||
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
|
||||
)
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
result = request_approval(
|
||||
action_type="linear_issue_update",
|
||||
tool_name="update_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_description": new_description,
|
||||
"new_state_id": new_state_id,
|
||||
"new_assignee_id": new_assignee_id,
|
||||
"new_priority": new_priority,
|
||||
"new_label_ids": new_label_ids,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
logger.info(
|
||||
f"Knowledge base successfully updated for issue {final_issue_id}"
|
||||
)
|
||||
kb_message = " Your knowledge base has also been updated."
|
||||
elif kb_result["status"] == "not_indexed":
|
||||
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
else:
|
||||
logger.warning(
|
||||
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
|
||||
)
|
||||
kb_message = " Your knowledge base will be updated in the next scheduled sync."
|
||||
else:
|
||||
kb_message = ""
|
||||
|
||||
identifier = updated_issue.get("identifier")
|
||||
default_msg = f"Issue {identifier} updated successfully."
|
||||
return {
|
||||
"status": "success",
|
||||
"identifier": identifier,
|
||||
"url": updated_issue.get("url"),
|
||||
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
|
||||
}
|
||||
if result.rejected:
|
||||
logger.info("Linear issue update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
final_new_title = result.params.get("new_title", new_title)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_state_id = result.params.get("new_state_id", new_state_id)
|
||||
final_new_assignee_id = result.params.get(
|
||||
"new_assignee_id", new_assignee_id
|
||||
)
|
||||
final_new_priority = result.params.get("new_priority", new_priority)
|
||||
final_new_label_ids: list[str] | None = result.params.get(
|
||||
"new_label_ids", new_label_ids
|
||||
)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
if not final_connector_id:
|
||||
logger.error("No connector found for this issue")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
logger.info(f"Validated Linear connector: id={final_connector_id}")
|
||||
|
||||
logger.info(
|
||||
f"Updating Linear issue with final params: issue_id={final_issue_id}"
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
updated_issue = await linear_client.update_issue(
|
||||
issue_id=final_issue_id,
|
||||
title=final_new_title,
|
||||
description=final_new_description,
|
||||
state_id=final_new_state_id,
|
||||
assignee_id=final_new_assignee_id,
|
||||
priority=final_new_priority,
|
||||
label_ids=final_new_label_ids,
|
||||
)
|
||||
|
||||
if updated_issue.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to update Linear issue: {updated_issue.get('message')}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": updated_issue.get("message"),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"update_issue result: {updated_issue.get('identifier')} - {updated_issue.get('title')}"
|
||||
)
|
||||
|
||||
if final_document_id is not None:
|
||||
logger.info(
|
||||
f"Updating knowledge base for document {final_document_id}..."
|
||||
)
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
logger.info(
|
||||
f"Knowledge base successfully updated for issue {final_issue_id}"
|
||||
)
|
||||
kb_message = " Your knowledge base has also been updated."
|
||||
elif kb_result["status"] == "not_indexed":
|
||||
kb_message = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
else:
|
||||
logger.warning(
|
||||
f"KB update failed for issue {final_issue_id}: {kb_result.get('message')}"
|
||||
)
|
||||
kb_message = " Your knowledge base will be updated in the next scheduled sync."
|
||||
else:
|
||||
kb_message = ""
|
||||
|
||||
identifier = updated_issue.get("identifier")
|
||||
default_msg = f"Issue {identifier} updated successfully."
|
||||
return {
|
||||
"status": "success",
|
||||
"identifier": identifier,
|
||||
"url": updated_issue.get("url"),
|
||||
"message": f"{updated_issue.get('message', default_msg)}{kb_message}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
|
|
@ -17,6 +18,23 @@ def create_create_luma_event_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_luma_event tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_luma_event tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_luma_event(
|
||||
name: str,
|
||||
|
|
@ -40,83 +58,86 @@ def create_create_luma_event_tool(
|
|||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_luma_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="luma_create_event",
|
||||
tool_name="create_luma_event",
|
||||
params={
|
||||
"name": name,
|
||||
"start_at": start_at,
|
||||
"end_at": end_at,
|
||||
"description": description,
|
||||
"timezone": timezone,
|
||||
},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Event was not created.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_start = result.params.get("start_at", start_at)
|
||||
final_end = result.params.get("end_at", end_at)
|
||||
final_desc = result.params.get("description", description)
|
||||
final_tz = result.params.get("timezone", timezone)
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": final_name,
|
||||
"start_at": final_start,
|
||||
"end_at": final_end,
|
||||
"timezone": final_tz,
|
||||
}
|
||||
if final_desc:
|
||||
body["description_md"] = final_desc
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{LUMA_API}/event/create",
|
||||
headers=headers,
|
||||
json=body,
|
||||
result = request_approval(
|
||||
action_type="luma_create_event",
|
||||
tool_name="create_luma_event",
|
||||
params={
|
||||
"name": name,
|
||||
"start_at": start_at,
|
||||
"end_at": end_at,
|
||||
"description": description,
|
||||
"timezone": timezone,
|
||||
},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Luma Plus subscription required to create events via API.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Event was not created.",
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
|
||||
final_name = result.params.get("name", name)
|
||||
final_start = result.params.get("start_at", start_at)
|
||||
final_end = result.params.get("end_at", end_at)
|
||||
final_desc = result.params.get("description", description)
|
||||
final_tz = result.params.get("timezone", timezone)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": event_id,
|
||||
"message": f"Event '{final_name}' created on Luma.",
|
||||
}
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": final_name,
|
||||
"start_at": final_start,
|
||||
"end_at": final_end,
|
||||
"timezone": final_tz,
|
||||
}
|
||||
if final_desc:
|
||||
body["description_md"] = final_desc
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{LUMA_API}/event/create",
|
||||
headers=headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Luma Plus subscription required to create events via API.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code} — {resp.text[:200]}",
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": event_id,
|
||||
"message": f"Event '{final_name}' created on Luma.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_list_luma_events_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the list_luma_events tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured list_luma_events tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def list_luma_events(
|
||||
max_results: int = 25,
|
||||
|
|
@ -28,77 +47,80 @@ def create_list_luma_events_tool(
|
|||
Dictionary with status and a list of events including
|
||||
event_id, name, start_at, end_at, location, url.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_luma_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
all_entries: list[dict] = []
|
||||
cursor = None
|
||||
all_entries: list[dict] = []
|
||||
cursor = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
while len(all_entries) < max_results:
|
||||
params: dict[str, Any] = {
|
||||
"limit": min(100, max_results - len(all_entries))
|
||||
}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
while len(all_entries) < max_results:
|
||||
params: dict[str, Any] = {
|
||||
"limit": min(100, max_results - len(all_entries))
|
||||
}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/calendar/list-events",
|
||||
headers=headers,
|
||||
params=params,
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/calendar/list-events",
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
entries = data.get("entries", [])
|
||||
if not entries:
|
||||
break
|
||||
all_entries.extend(entries)
|
||||
|
||||
next_cursor = data.get("next_cursor")
|
||||
if not next_cursor:
|
||||
break
|
||||
cursor = next_cursor
|
||||
|
||||
events = []
|
||||
for entry in all_entries[:max_results]:
|
||||
ev = entry.get("event", {})
|
||||
geo = ev.get("geo_info", {})
|
||||
events.append(
|
||||
{
|
||||
"event_id": entry.get("api_id"),
|
||||
"name": ev.get("name", "Untitled"),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location": geo.get("name", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
}
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
entries = data.get("entries", [])
|
||||
if not entries:
|
||||
break
|
||||
all_entries.extend(entries)
|
||||
|
||||
next_cursor = data.get("next_cursor")
|
||||
if not next_cursor:
|
||||
break
|
||||
cursor = next_cursor
|
||||
|
||||
events = []
|
||||
for entry in all_entries[:max_results]:
|
||||
ev = entry.get("event", {})
|
||||
geo = ev.get("geo_info", {})
|
||||
events.append(
|
||||
{
|
||||
"event_id": entry.get("api_id"),
|
||||
"name": ev.get("name", "Untitled"),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location": geo.get("name", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_read_luma_event_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the read_luma_event tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured read_luma_event tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
||||
"""Read detailed information about a specific Luma event.
|
||||
|
|
@ -26,60 +45,63 @@ def create_read_luma_event_tool(
|
|||
Dictionary with status and full event details including
|
||||
description, attendees count, meeting URL.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/events/{event_id}",
|
||||
headers=headers,
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_luma_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code == 404:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Event '{event_id}' not found.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code}",
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/events/{event_id}",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Luma API key is invalid.",
|
||||
"connector_type": "luma",
|
||||
}
|
||||
if resp.status_code == 404:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Event '{event_id}' not found.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Luma API error: {resp.status_code}",
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
ev = data.get("event", data)
|
||||
geo = ev.get("geo_info", {})
|
||||
|
||||
event_detail = {
|
||||
"event_id": event_id,
|
||||
"name": ev.get("name", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location_name": geo.get("name", ""),
|
||||
"address": geo.get("address", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"meeting_url": ev.get("meeting_url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
"cover_url": ev.get("cover_url", ""),
|
||||
}
|
||||
|
||||
data = resp.json()
|
||||
ev = data.get("event", data)
|
||||
geo = ev.get("geo_info", {})
|
||||
|
||||
event_detail = {
|
||||
"event_id": event_id,
|
||||
"name": ev.get("name", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location_name": geo.get("name", ""),
|
||||
"address": geo.get("address", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"meeting_url": ev.get("meeting_url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
"cover_url": ev.get("cover_url", ""),
|
||||
}
|
||||
|
||||
return {"status": "success", "event": event_detail}
|
||||
return {"status": "success", "event": event_detail}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.notion import NotionToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -20,8 +21,17 @@ def create_create_notion_page_tool(
|
|||
"""
|
||||
Factory function to create the create_notion_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker`. This is critical for the compiled-agent
|
||||
cache: the compiled graph (and therefore this closure) is reused
|
||||
across HTTP requests, so capturing a per-request session here would
|
||||
surface stale/closed sessions on cache hits. Per-call sessions also
|
||||
keep the request's outer transaction free of long-running Notion API
|
||||
blocking.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing Notion connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Notion connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +39,7 @@ def create_create_notion_page_tool(
|
|||
Returns:
|
||||
Configured create_notion_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_notion_page(
|
||||
|
|
@ -67,7 +78,7 @@ def create_create_notion_page_tool(
|
|||
"""
|
||||
logger.info(f"create_notion_page called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Notion tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -77,154 +88,157 @@ def create_create_notion_page_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": context["error"],
|
||||
}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Notion accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "notion",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Notion page: '{title}'")
|
||||
result = request_approval(
|
||||
action_type="notion_page_creation",
|
||||
tool_name="create_notion_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"parent_page_id": None,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Notion page creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content)
|
||||
final_parent_page_id = result.params.get("parent_page_id")
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Page title cannot be empty. Please provide a valid title.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Creating Notion page with final params: title='{final_title}'"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.warning(
|
||||
f"No Notion connector found for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Found Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
if "error" in context:
|
||||
logger.error(
|
||||
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
|
||||
f"Failed to fetch creation context: {context['error']}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
"message": context["error"],
|
||||
}
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Notion accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "notion",
|
||||
}
|
||||
|
||||
result = await notion_connector.create_page(
|
||||
title=final_title,
|
||||
content=final_content,
|
||||
parent_page_id=final_parent_page_id,
|
||||
)
|
||||
logger.info(
|
||||
f"create_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
logger.info(f"Requesting approval for creating Notion page: '{title}'")
|
||||
result = request_approval(
|
||||
action_type="notion_page_creation",
|
||||
tool_name="create_notion_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"parent_page_id": None,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
if result.rejected:
|
||||
logger.info("Notion page creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=result.get("page_id"),
|
||||
page_title=result.get("title", final_title),
|
||||
page_url=result.get("url"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content)
|
||||
final_parent_page_id = result.params.get("parent_page_id")
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Page title cannot be empty. Please provide a valid title.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Creating Notion page with final params: title='{final_title}'"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
else:
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.warning(
|
||||
f"No Notion connector found for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Notion connector found. Please connect Notion in your workspace settings.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Found Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
}
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
|
||||
result = await notion_connector.create_page(
|
||||
title=final_title,
|
||||
content=final_content,
|
||||
parent_page_id=final_parent_page_id,
|
||||
)
|
||||
logger.info(
|
||||
f"create_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=result.get("page_id"),
|
||||
page_title=result.get("title", final_title),
|
||||
page_url=result.get("url"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
result["message"] = result.get("message", "") + kb_message_suffix
|
||||
result["message"] = result.get("message", "") + kb_message_suffix
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.notion.tool_metadata_service import NotionToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -20,8 +21,14 @@ def create_delete_notion_page_tool(
|
|||
"""
|
||||
Factory function to create the delete_notion_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing Notion connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Notion connector
|
||||
user_id: User ID for finding the correct Notion connector
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +36,7 @@ def create_delete_notion_page_tool(
|
|||
Returns:
|
||||
Configured delete_notion_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_notion_page(
|
||||
|
|
@ -63,7 +71,7 @@ def create_delete_notion_page_tool(
|
|||
f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Notion tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -73,164 +81,167 @@ def create_delete_notion_page_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
# Get page context (page_id, account, title) from indexed data
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_delete_context(
|
||||
search_space_id, user_id, page_title
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
# Check if it's a "not found" error (softer handling for LLM)
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Page not found: {error_msg}")
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": error_msg,
|
||||
}
|
||||
else:
|
||||
logger.error(f"Failed to fetch delete context: {error_msg}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
async with async_session_maker() as db_session:
|
||||
# Get page context (page_id, account, title) from indexed data
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_delete_context(
|
||||
search_space_id, user_id, page_title
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
connector_id_from_context = account.get("id")
|
||||
document_id = context.get("document_id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
|
||||
result = request_approval(
|
||||
action_type="notion_page_deletion",
|
||||
tool_name="delete_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Notion page deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
# Validate the connector
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this page")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
# Create connector instance
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
|
||||
# Delete the page from Notion
|
||||
result = await notion_connector.delete_page(page_id=final_page_id)
|
||||
logger.info(
|
||||
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
# If deletion was successful and user wants to delete from KB
|
||||
deleted_from_kb = False
|
||||
if (
|
||||
result.get("status") == "success"
|
||||
and final_delete_from_kb
|
||||
and document_id
|
||||
):
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Document
|
||||
|
||||
# Get the document
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
# Check if it's a "not found" error (softer handling for LLM)
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Page not found: {error_msg}")
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": error_msg,
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
result["warning"] = (
|
||||
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
logger.error(f"Failed to fetch delete context: {error_msg}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
}
|
||||
|
||||
# Update result with KB deletion status
|
||||
if result.get("status") == "success":
|
||||
result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
result["message"] = (
|
||||
f"{result.get('message', '')} (also removed from knowledge base)"
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
return result
|
||||
page_id = context.get("page_id")
|
||||
connector_id_from_context = account.get("id")
|
||||
document_id = context.get("document_id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
|
||||
result = request_approval(
|
||||
action_type="notion_page_deletion",
|
||||
tool_name="delete_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Notion page deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
# Validate the connector
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this page")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
# Create connector instance
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
|
||||
# Delete the page from Notion
|
||||
result = await notion_connector.delete_page(page_id=final_page_id)
|
||||
logger.info(
|
||||
f"delete_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
# If deletion was successful and user wants to delete from KB
|
||||
deleted_from_kb = False
|
||||
if (
|
||||
result.get("status") == "success"
|
||||
and final_delete_from_kb
|
||||
and document_id
|
||||
):
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Document
|
||||
|
||||
# Get the document
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
result["warning"] = (
|
||||
f"Page deleted from Notion, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
# Update result with KB deletion status
|
||||
if result.get("status") == "success":
|
||||
result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
result["message"] = (
|
||||
f"{result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.db import async_session_maker
|
||||
from app.services.notion import NotionToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -20,8 +21,14 @@ def create_update_notion_page_tool(
|
|||
"""
|
||||
Factory function to create the update_notion_page tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache (see
|
||||
``create_create_notion_page_tool`` for the full rationale).
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing Notion connector
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
search_space_id: Search space ID to find the Notion connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
|
@ -29,6 +36,7 @@ def create_update_notion_page_tool(
|
|||
Returns:
|
||||
Configured update_notion_page tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_notion_page(
|
||||
|
|
@ -71,7 +79,7 @@ def create_update_notion_page_tool(
|
|||
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Notion tool not properly configured - missing required parameters"
|
||||
)
|
||||
|
|
@ -88,152 +96,155 @@ def create_update_notion_page_tool(
|
|||
}
|
||||
|
||||
try:
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
# Check if it's a "not found" error (softer handling for LLM)
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Page not found: {error_msg}")
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": error_msg,
|
||||
}
|
||||
else:
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
document_id = context.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="notion_page_update",
|
||||
tool_name="update_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"content": content,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Notion page update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this page")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
|
||||
result = await notion_connector.update_page(
|
||||
page_id=final_page_id,
|
||||
content=final_content,
|
||||
)
|
||||
logger.info(
|
||||
f"update_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success" and document_id is not None:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
|
||||
logger.info(f"Updating knowledge base for document {document_id}...")
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
appended_content=final_content,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
appended_block_ids=result.get("appended_block_ids"),
|
||||
async with async_session_maker() as db_session:
|
||||
metadata_service = NotionToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title
|
||||
)
|
||||
|
||||
if kb_result["status"] == "success":
|
||||
result["message"] = (
|
||||
f"{result['message']}. Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"Knowledge base successfully updated for page {final_page_id}"
|
||||
)
|
||||
elif kb_result["status"] == "not_indexed":
|
||||
result["message"] = (
|
||||
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
|
||||
)
|
||||
else:
|
||||
result["message"] = (
|
||||
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
|
||||
)
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
# Check if it's a "not found" error (softer handling for LLM)
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Page not found: {error_msg}")
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": error_msg,
|
||||
}
|
||||
else:
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
f"KB update failed for page {final_page_id}: {kb_result['message']}"
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
document_id = context.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="notion_page_update",
|
||||
tool_name="update_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"content": content,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Notion page update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Notion connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this page")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
notion_connector = NotionHistoryConnector(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
|
||||
result = await notion_connector.update_page(
|
||||
page_id=final_page_id,
|
||||
content=final_content,
|
||||
)
|
||||
logger.info(
|
||||
f"update_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success" and document_id is not None:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
|
||||
logger.info(
|
||||
f"Updating knowledge base for document {document_id}..."
|
||||
)
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
appended_content=final_content,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
appended_block_ids=result.get("appended_block_ids"),
|
||||
)
|
||||
|
||||
return result
|
||||
if kb_result["status"] == "success":
|
||||
result["message"] = (
|
||||
f"{result['message']}. Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"Knowledge base successfully updated for page {final_page_id}"
|
||||
)
|
||||
elif kb_result["status"] == "not_indexed":
|
||||
result["message"] = (
|
||||
f"{result['message']}. This page will be added to your knowledge base in the next scheduled sync."
|
||||
)
|
||||
else:
|
||||
result["message"] = (
|
||||
f"{result['message']}. Your knowledge base will be updated in the next scheduled sync."
|
||||
)
|
||||
logger.warning(
|
||||
f"KB update failed for page {final_page_id}: {kb_result['message']}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.onedrive.client import OneDriveClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -48,6 +48,23 @@ def create_create_onedrive_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_onedrive_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured create_onedrive_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def create_onedrive_file(
|
||||
name: str,
|
||||
|
|
@ -70,173 +87,178 @@ def create_create_onedrive_file_tool(
|
|||
"""
|
||||
logger.info(f"create_onedrive_file called: name='{name}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected OneDrive accounts need re-authentication.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = OneDriveClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_children("root")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{"folder_id": item["id"], "name": item["name"]}
|
||||
for item in items
|
||||
if item.get("folder") is not None
|
||||
and item.get("id")
|
||||
and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s", cid, exc_info=True
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_creation",
|
||||
tool_name="create_onedrive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_docx_extension(final_name)
|
||||
|
||||
if final_connector_id is not None:
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
connector = connectors[0]
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid.",
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected OneDrive accounts need re-authentication.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = OneDriveClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_children("root")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{"folder_id": item["id"], "name": item["name"]}
|
||||
for item in items
|
||||
if item.get("folder") is not None
|
||||
and item.get("id")
|
||||
and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s",
|
||||
cid,
|
||||
exc_info=True,
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
}
|
||||
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
|
||||
client = OneDriveClient(session=db_session, connector_id=connector.id)
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
parent_id=final_parent_folder_id,
|
||||
content=docx_bytes,
|
||||
mime_type=DOCX_MIME,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.onedrive import OneDriveKBSyncService
|
||||
|
||||
kb_service = OneDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=DOCX_MIME,
|
||||
web_url=created.get("webUrl"),
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_creation",
|
||||
tool_name="create_onedrive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_url": created.get("webUrl"),
|
||||
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_docx_extension(final_name)
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
connector = connectors[0]
|
||||
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid.",
|
||||
}
|
||||
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
|
||||
client = OneDriveClient(session=db_session, connector_id=connector.id)
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
parent_id=final_parent_folder_id,
|
||||
content=docx_bytes,
|
||||
mime_type=DOCX_MIME,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.onedrive import OneDriveKBSyncService
|
||||
|
||||
kb_service = OneDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=DOCX_MIME,
|
||||
web_url=created.get("webUrl"),
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_url": created.get("webUrl"),
|
||||
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.db import (
|
|||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
async_session_maker,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +24,23 @@ def create_delete_onedrive_file_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_onedrive_file tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured delete_onedrive_file tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def delete_onedrive_file(
|
||||
file_name: str,
|
||||
|
|
@ -56,33 +74,14 @@ def create_delete_onedrive_file_tool(
|
|||
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
async with async_session_maker() as db_session:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
|
|
@ -93,13 +92,7 @@ def create_delete_onedrive_file_tool(
|
|||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata["onedrive_file_name"],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -108,98 +101,64 @@ def create_delete_onedrive_file_tool(
|
|||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed OneDrive files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
meta = document.document_metadata or {}
|
||||
file_id = meta.get("onedrive_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
if not document:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata[
|
||||
"onedrive_file_name"
|
||||
],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive connector not found or access denied.",
|
||||
}
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed OneDrive files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
"web_url": meta.get("web_url"),
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_trash",
|
||||
tool_name="delete_onedrive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
meta = document.document_metadata or {}
|
||||
file_id = meta.get("onedrive_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
|
|
@ -207,65 +166,130 @@ def create_delete_onedrive_file_tool(
|
|||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid or has been disconnected.",
|
||||
"message": "OneDrive connector not found or access denied.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
|
||||
)
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
|
||||
client = OneDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
await client.trash_file(final_file_id)
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
"web_url": meta.get("web_url"),
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file_name}' to the recycle bin.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_trash",
|
||||
tool_name="delete_onedrive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
return trash_result
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get(
|
||||
"delete_from_kb", delete_from_kb
|
||||
)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id
|
||||
== search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
client = OneDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
await client.trash_file(final_file_id)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file_name}' to the recycle bin.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -824,13 +824,22 @@ async def build_tools_async(
|
|||
"""Async version of build_tools that also loads MCP tools from database.
|
||||
|
||||
Design Note:
|
||||
This function exists because MCP tools require database queries to load user configs,
|
||||
while built-in tools are created synchronously from static code.
|
||||
This function exists because MCP tools require database queries to load
|
||||
user configs, while built-in tools are created synchronously from static
|
||||
code.
|
||||
|
||||
Alternative: We could make build_tools() itself async and always query the database,
|
||||
but that would force async everywhere even when only using built-in tools. The current
|
||||
design keeps the simple case (static tools only) synchronous while supporting dynamic
|
||||
database-loaded tools through this async wrapper.
|
||||
Alternative: We could make build_tools() itself async and always query
|
||||
the database, but that would force async everywhere even when only using
|
||||
built-in tools. The current design keeps the simple case (static tools
|
||||
only) synchronous while supporting dynamic database-loaded tools through
|
||||
this async wrapper.
|
||||
|
||||
Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
|
||||
avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
|
||||
the event loop) are kicked off concurrently. Cold-path latency is
|
||||
bounded by the slower of the two — typically MCP at ~200ms-1.7s —
|
||||
so the parallelization recovers the ~50-200ms previously spent
|
||||
serially on built-in construction.
|
||||
|
||||
Args:
|
||||
dependencies: Dict containing all possible dependencies
|
||||
|
|
@ -843,33 +852,70 @@ async def build_tools_async(
|
|||
List of configured tool instances ready for the agent, including MCP tools.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
|
||||
can_load_mcp = (
|
||||
include_mcp_tools
|
||||
and "db_session" in dependencies
|
||||
and "search_space_id" in dependencies
|
||||
)
|
||||
|
||||
# Built-in tool construction is synchronous + CPU-only. Off-loop it so
|
||||
# MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is a pure
|
||||
# function over its inputs — safe to thread-shift.
|
||||
_t0 = time.perf_counter()
|
||||
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
|
||||
builtin_task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
|
||||
)
|
||||
)
|
||||
|
||||
mcp_task: asyncio.Task | None = None
|
||||
if can_load_mcp:
|
||||
mcp_task = asyncio.create_task(
|
||||
load_mcp_tools(
|
||||
dependencies["db_session"],
|
||||
dependencies["search_space_id"],
|
||||
)
|
||||
)
|
||||
|
||||
# Surface failures from each task independently so a flaky MCP
|
||||
# endpoint never poisons built-in tool registration. ``return_exceptions``
|
||||
# gives us per-task exceptions instead of dropping the second result
|
||||
# when the first raises.
|
||||
if mcp_task is not None:
|
||||
builtin_result, mcp_result = await asyncio.gather(
|
||||
builtin_task, mcp_task, return_exceptions=True
|
||||
)
|
||||
else:
|
||||
builtin_result = await builtin_task
|
||||
mcp_result = None
|
||||
|
||||
if isinstance(builtin_result, BaseException):
|
||||
raise builtin_result # built-in registration failure is non-recoverable
|
||||
tools: list[BaseTool] = builtin_result
|
||||
_perf_log.info(
|
||||
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
|
||||
"[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
|
||||
time.perf_counter() - _t0,
|
||||
len(tools),
|
||||
)
|
||||
|
||||
# Load MCP tools if requested and dependencies are available
|
||||
if (
|
||||
include_mcp_tools
|
||||
and "db_session" in dependencies
|
||||
and "search_space_id" in dependencies
|
||||
):
|
||||
try:
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools = await load_mcp_tools(
|
||||
dependencies["db_session"],
|
||||
dependencies["search_space_id"],
|
||||
if mcp_task is not None:
|
||||
if isinstance(mcp_result, BaseException):
|
||||
# ``return_exceptions=True`` captures the exception out-of-band,
|
||||
# so ``sys.exc_info()`` is empty here. Pass the captured
|
||||
# exception via ``exc_info=`` to get a real traceback.
|
||||
logging.error(
|
||||
"Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
|
||||
)
|
||||
else:
|
||||
mcp_tools = mcp_result or []
|
||||
_perf_log.info(
|
||||
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
|
||||
"[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
|
||||
time.perf_counter() - _t0,
|
||||
len(mcp_tools),
|
||||
)
|
||||
|
|
@ -879,8 +925,6 @@ async def build_tools_async(
|
|||
len(mcp_tools),
|
||||
[t.name for t in mcp_tools],
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("Failed to load MCP tools: %s", e)
|
||||
|
||||
logging.info(
|
||||
"Total tools for agent: %d — %s",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
|
||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
|
||||
|
|
@ -124,12 +124,19 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
|
|||
"""
|
||||
Factory function to create the search_surfsense_docs tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Database session for executing queries
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
A configured tool function for searching Surfsense documentation
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
|
||||
|
|
@ -155,10 +162,11 @@ def create_search_surfsense_docs_tool(db_session: AsyncSession):
|
|||
Returns:
|
||||
Relevant documentation content formatted with chunk IDs for citations
|
||||
"""
|
||||
return await search_surfsense_docs_async(
|
||||
query=query,
|
||||
db_session=db_session,
|
||||
top_k=top_k,
|
||||
)
|
||||
async with async_session_maker() as db_session:
|
||||
return await search_surfsense_docs_async(
|
||||
query=query,
|
||||
db_session=db_session,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
return search_surfsense_docs
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_list_teams_channels_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the list_teams_channels tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured list_teams_channels tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def list_teams_channels() -> dict[str, Any]:
|
||||
"""List all Microsoft Teams and their channels the user has access to.
|
||||
|
|
@ -23,63 +42,66 @@ def create_list_teams_channels_tool(
|
|||
Dictionary with status and a list of teams, each containing
|
||||
team_id, team_name, and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
teams_resp = await client.get(
|
||||
f"{GRAPH_API}/me/joinedTeams", headers=headers
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_teams_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
if teams_resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if teams_resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {teams_resp.status_code}",
|
||||
}
|
||||
token = await get_access_token(db_session, connector)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
teams_data = teams_resp.json().get("value", [])
|
||||
result_teams = []
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for team in teams_data:
|
||||
team_id = team["id"]
|
||||
ch_resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||
headers=headers,
|
||||
)
|
||||
channels = []
|
||||
if ch_resp.status_code == 200:
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||
for ch in ch_resp.json().get("value", [])
|
||||
]
|
||||
result_teams.append(
|
||||
{
|
||||
"team_id": team_id,
|
||||
"team_name": team.get("displayName", ""),
|
||||
"channels": channels,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
teams_resp = await client.get(
|
||||
f"{GRAPH_API}/me/joinedTeams", headers=headers
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"teams": result_teams,
|
||||
"total_teams": len(result_teams),
|
||||
}
|
||||
if teams_resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if teams_resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {teams_resp.status_code}",
|
||||
}
|
||||
|
||||
teams_data = teams_resp.json().get("value", [])
|
||||
result_teams = []
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for team in teams_data:
|
||||
team_id = team["id"]
|
||||
ch_resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||
headers=headers,
|
||||
)
|
||||
channels = []
|
||||
if ch_resp.status_code == 200:
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||
for ch in ch_resp.json().get("value", [])
|
||||
]
|
||||
result_teams.append(
|
||||
{
|
||||
"team_id": team_id,
|
||||
"team_name": team.get("displayName", ""),
|
||||
"channels": channels,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"teams": result_teams,
|
||||
"total_teams": len(result_teams),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import httpx
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -15,6 +17,23 @@ def create_read_teams_messages_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the read_teams_messages tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured read_teams_messages tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def read_teams_messages(
|
||||
team_id: str,
|
||||
|
|
@ -32,65 +51,68 @@ def create_read_teams_messages_tool(
|
|||
Dictionary with status and a list of messages including
|
||||
id, sender, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={"$top": limit},
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_teams_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Insufficient permissions to read this channel.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {resp.status_code}",
|
||||
}
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
raw_msgs = resp.json().get("value", [])
|
||||
messages = []
|
||||
for m in raw_msgs:
|
||||
sender = m.get("from", {})
|
||||
user_info = sender.get("user", {}) if sender else {}
|
||||
body = m.get("body", {})
|
||||
messages.append(
|
||||
{
|
||||
"id": m.get("id"),
|
||||
"sender": user_info.get("displayName", "Unknown"),
|
||||
"content": body.get("content", ""),
|
||||
"content_type": body.get("contentType", "text"),
|
||||
"timestamp": m.get("createdDateTime", ""),
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={"$top": limit},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Insufficient permissions to read this channel.",
|
||||
}
|
||||
if resp.status_code != 200:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {resp.status_code}",
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
raw_msgs = resp.json().get("value", [])
|
||||
messages = []
|
||||
for m in raw_msgs:
|
||||
sender = m.get("from", {})
|
||||
user_info = sender.get("user", {}) if sender else {}
|
||||
body = m.get("body", {})
|
||||
messages.append(
|
||||
{
|
||||
"id": m.get("id"),
|
||||
"sender": user_info.get("displayName", "Unknown"),
|
||||
"content": body.get("content", ""),
|
||||
"content_type": body.get("contentType", "text"),
|
||||
"timestamp": m.get("createdDateTime", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.db import async_session_maker
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
|
|
@ -17,6 +18,23 @@ def create_send_teams_message_tool(
|
|||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the send_teams_message tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
|
||||
Args:
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
|
||||
Returns:
|
||||
Configured send_teams_message tool
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def send_teams_message(
|
||||
team_id: str,
|
||||
|
|
@ -39,70 +57,73 @@ def create_send_teams_message_tool(
|
|||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
if search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
async with async_session_maker() as db_session:
|
||||
connector = await get_teams_connector(
|
||||
db_session, search_space_id, user_id
|
||||
)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="teams_send_message",
|
||||
tool_name="send_teams_message",
|
||||
params={
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"content": content,
|
||||
},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Message was not sent.",
|
||||
}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_team = result.params.get("team_id", team_id)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
result = request_approval(
|
||||
action_type="teams_send_message",
|
||||
tool_name="send_teams_message",
|
||||
params={
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"content": content,
|
||||
},
|
||||
json={"body": {"content": final_content}},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
|
||||
}
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Message was not sent.",
|
||||
}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": "Message sent to Teams channel.",
|
||||
}
|
||||
final_content = result.params.get("content", content)
|
||||
final_team = result.params.get("team_id", team_id)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"body": {"content": final_content}},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Teams token expired. Please re-authenticate.",
|
||||
"connector_type": "teams",
|
||||
}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Graph API error: {resp.status_code} — {resp.text[:200]}",
|
||||
}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": "Message sent to Teams channel.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSpace, User
|
||||
from app.db import SearchSpace, User, async_session_maker
|
||||
from app.utils.content_utils import extract_text_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -302,6 +302,25 @@ def create_update_memory_tool(
|
|||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
"""Factory function to create the user-memory update tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
The session's bound ``commit``/``rollback`` methods are captured at
|
||||
call time, after ``async with`` has bound ``db_session`` locally.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user whose memory document is being updated.
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
llm: Optional LLM for the forced-rewrite path.
|
||||
|
||||
Returns:
|
||||
Configured update_memory tool for the user-memory scope.
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
|
|
@ -318,26 +337,26 @@ def create_update_memory_tool(
|
|||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
|
|
@ -351,6 +370,27 @@ def create_update_team_memory_tool(
|
|||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
"""Factory function to create the team-memory update tool.
|
||||
|
||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
||||
:data:`async_session_maker` so the closure is safe to share across
|
||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
||||
session here would surface stale/closed sessions on cache hits.
|
||||
The session's bound ``commit``/``rollback`` methods are captured at
|
||||
call time, after ``async with`` has bound ``db_session`` locally.
|
||||
|
||||
Args:
|
||||
search_space_id: ID of the search space whose team memory is being
|
||||
updated.
|
||||
db_session: Reserved for registry compatibility. Per-call sessions
|
||||
are opened via :data:`async_session_maker` inside the tool body.
|
||||
llm: Optional LLM for the forced-rewrite path.
|
||||
|
||||
Returns:
|
||||
Configured update_memory tool for the team-memory scope.
|
||||
"""
|
||||
del db_session # per-call session — see docstring
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
|
@ -366,28 +406,30 @@ def create_update_team_memory_tool(
|
|||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
async with async_session_maker() as db_session:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(
|
||||
space, "shared_memory_md", content
|
||||
),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update team memory: {e}",
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from app.config import (
|
|||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
|
|
@ -420,6 +421,135 @@ def _stop_openrouter_background_refresh() -> None:
|
|||
OpenRouterIntegrationService.get_instance().stop_background_refresh()
|
||||
|
||||
|
||||
async def _warm_agent_jit_caches() -> None:
|
||||
"""Pay the LangChain / LangGraph / Deepagents JIT cost at startup.
|
||||
|
||||
Why
|
||||
----
|
||||
A cold ``create_agent`` + ``StateGraph.compile()`` + Pydantic schema
|
||||
generation chain takes 1.5-2 seconds of pure CPU on first invocation
|
||||
inside any Python process: the graph compiler builds reducers,
|
||||
Pydantic v2 generates and JITs validator schemas, deepagents
|
||||
eagerly compiles its general-purpose subagent, etc. Subsequent
|
||||
compiles in the same process pay only ~50% of that cost (the lazy
|
||||
JIT bits are cached in module-level dicts).
|
||||
|
||||
Doing one throwaway compile during ``lifespan`` startup pre-pays
|
||||
that cost so the *first real request* doesn't. We do NOT prime
|
||||
:mod:`agent_cache` because the cache key requires real
|
||||
``thread_id`` / ``user_id`` / ``search_space_id`` / etc. — the
|
||||
throwaway agent is genuinely thrown away and immediately collected.
|
||||
|
||||
Safety
|
||||
------
|
||||
* No DB access. We construct a stub LLM (no real keys), pass an
|
||||
empty tools list, and pass ``checkpointer=None`` so we never
|
||||
touch Postgres.
|
||||
* Bounded by ``asyncio.wait_for`` so a hang here can never block
|
||||
worker startup. On any failure, we log + swallow — the worst
|
||||
case is the first real request pays the full cold cost (i.e.
|
||||
pre-warmup behaviour).
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
t0 = _time.perf_counter()
|
||||
try:
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
ModelCallLimitMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModel,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
||||
# Minimal LLM stub. ``FakeListChatModel`` satisfies
|
||||
# ``BaseChatModel`` without any network or auth — perfect for
|
||||
# exercising the compile path without side effects.
|
||||
stub_llm = FakeListChatModel(responses=["warmup-response"])
|
||||
|
||||
# Two trivial tools with arg + return schemas — exercises the
|
||||
# Pydantic v2 schema JIT path. Without at least one tool the
|
||||
# graph compile skips the tool-loop bytecode generation that
|
||||
# accounts for ~30-50% of cold compile cost.
|
||||
@tool
|
||||
def _warmup_tool_a(query: str, limit: int = 5) -> str:
|
||||
"""Warmup tool A — never actually invoked."""
|
||||
return query[:limit]
|
||||
|
||||
@tool
|
||||
def _warmup_tool_b(name: str, value: float | None = None) -> dict[str, object]:
|
||||
"""Warmup tool B — never actually invoked."""
|
||||
return {"name": name, "value": value}
|
||||
|
||||
# A handful of common middleware so the compile pre-pays the
|
||||
# ``AgentMiddleware`` resolver path. These instances never run
|
||||
# because the throwaway agent is immediately collected.
|
||||
# ``SubAgentMiddleware`` is the single heaviest line in cold
|
||||
# ``create_surfsense_deep_agent`` (1.5-2s of CPU per call to
|
||||
# compile its general-purpose subagent's full inner graph),
|
||||
# so we include it here to make sure that compile path is JIT'd.
|
||||
warmup_middleware: list = [
|
||||
TodoListMiddleware(),
|
||||
ModelCallLimitMiddleware(
|
||||
thread_limit=120, run_limit=80, exit_behavior="end"
|
||||
),
|
||||
ToolCallLimitMiddleware(
|
||||
thread_limit=300, run_limit=80, exit_behavior="continue"
|
||||
),
|
||||
]
|
||||
try:
|
||||
from deepagents import SubAgentMiddleware
|
||||
from deepagents.backends import StateBackend
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
|
||||
gp_warmup_spec = { # type: ignore[var-annotated]
|
||||
**GENERAL_PURPOSE_SUBAGENT,
|
||||
"model": stub_llm,
|
||||
"tools": [_warmup_tool_a],
|
||||
"middleware": [TodoListMiddleware()],
|
||||
}
|
||||
warmup_middleware.append(
|
||||
SubAgentMiddleware(backend=StateBackend, subagents=[gp_warmup_spec])
|
||||
)
|
||||
except Exception:
|
||||
# Deepagents missing/incompatible — middleware-only warmup
|
||||
# still produces a useful (smaller) speedup.
|
||||
logger.debug("[startup] SubAgentMiddleware warmup skipped", exc_info=True)
|
||||
|
||||
compiled = create_agent(
|
||||
stub_llm,
|
||||
tools=[_warmup_tool_a, _warmup_tool_b],
|
||||
system_prompt="You are a warmup stub.",
|
||||
middleware=warmup_middleware,
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=None,
|
||||
)
|
||||
|
||||
# Touch the compiled graph's stream_channels / nodes so any
|
||||
# remaining lazy schema work fires now instead of on first
|
||||
# real invocation.
|
||||
_ = list(getattr(compiled, "nodes", {}).keys())
|
||||
|
||||
del compiled
|
||||
logger.info(
|
||||
"[startup] Agent JIT warmup completed in %.3fs",
|
||||
_time.perf_counter() - t0,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[startup] Agent JIT warmup failed in %.3fs (non-fatal — first "
|
||||
"real request will pay the full compile cost)",
|
||||
_time.perf_counter() - t0,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||
|
|
@ -432,6 +562,7 @@ async def lifespan(app: FastAPI):
|
|||
await setup_checkpointer_tables()
|
||||
initialize_openrouter_integration()
|
||||
_start_openrouter_background_refresh()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
@ -443,6 +574,18 @@ async def lifespan(app: FastAPI):
|
|||
"Docs will be indexed on the next restart."
|
||||
)
|
||||
|
||||
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||
# doesn't leave half-warmed Pydantic schemas in an inconsistent
|
||||
# state.
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(_warm_agent_jit_caches()), timeout=20)
|
||||
except (TimeoutError, Exception): # pragma: no cover - defensive
|
||||
logging.getLogger(__name__).warning(
|
||||
"[startup] Agent JIT warmup hit timeout/error — skipping; "
|
||||
"first real request will pay the full compile cost."
|
||||
)
|
||||
|
||||
log_system_snapshot("startup_complete")
|
||||
|
||||
yield
|
||||
|
|
@ -452,6 +595,23 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
|
||||
def registration_allowed():
|
||||
"""Master auth kill switch keyed on the REGISTRATION_ENABLED env var.
|
||||
|
||||
Despite the name, this dependency does NOT only gate registration. When
|
||||
REGISTRATION_ENABLED is FALSE it intentionally blocks every auth surface
|
||||
that could mint or refresh a session for an attacker:
|
||||
|
||||
* email/password ``POST /auth/register``
|
||||
* email/password ``POST /auth/jwt/login``
|
||||
* the Google OAuth router (``/auth/google/authorize`` plus the shared
|
||||
``/auth/google/callback``, which handles both new signups and login
|
||||
for existing users, so flipping this off locks both)
|
||||
* the bespoke ``/auth/google/authorize-redirect`` helper used by the UI
|
||||
|
||||
Use it as a temporary "freeze all new sessions" lever during incident
|
||||
response. It is not a way to disable signup while keeping login working;
|
||||
for that, override ``UserManager.oauth_callback`` instead.
|
||||
"""
|
||||
if not config.REGISTRATION_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Registration is disabled"
|
||||
|
|
@ -596,32 +756,45 @@ app.add_middleware(
|
|||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth/jwt",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(rate_limit_login)],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
dependencies=[
|
||||
Depends(rate_limit_register),
|
||||
Depends(registration_allowed), # blocks registration when disabled
|
||||
],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(rate_limit_password_reset)],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
# Password / email-based auth routers are only mounted when not running in
|
||||
# Google-OAuth-only mode. Mounting them in OAuth-only prod previously left
|
||||
# POST /auth/register reachable, which is the bypass that allowed bots to
|
||||
# create non-OAuth users in spite of AUTH_TYPE=GOOGLE.
|
||||
if config.AUTH_TYPE != "GOOGLE":
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth/jwt",
|
||||
tags=["auth"],
|
||||
dependencies=[
|
||||
Depends(rate_limit_login),
|
||||
Depends(
|
||||
registration_allowed
|
||||
), # honour REGISTRATION_ENABLED kill switch on login too
|
||||
],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
dependencies=[
|
||||
Depends(rate_limit_register),
|
||||
Depends(registration_allowed),
|
||||
],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(rate_limit_password_reset)],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# /users/me (read/update profile) is needed in every auth mode, so it stays
|
||||
# mounted unconditionally.
|
||||
app.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
|
|
@ -679,16 +852,25 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
dependencies=[
|
||||
Depends(registration_allowed)
|
||||
], # blocks OAuth registration when disabled
|
||||
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
|
||||
# it blocks BOTH new OAuth signups AND login of existing OAuth users
|
||||
# (the fastapi-users OAuth router shares one callback for create+login,
|
||||
# so this dependency closes both paths together).
|
||||
dependencies=[Depends(registration_allowed)],
|
||||
)
|
||||
|
||||
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
||||
# This endpoint performs a server-side redirect instead of returning JSON
|
||||
# which fixes cross-site cookie issues where browsers don't send cookies
|
||||
# set via cross-origin fetch requests on subsequent redirects
|
||||
@app.get("/auth/google/authorize-redirect", tags=["auth"])
|
||||
# set via cross-origin fetch requests on subsequent redirects.
|
||||
# The registration_allowed dependency mirrors the OAuth router above so
|
||||
# the kill switch fails fast here instead of bouncing users to Google
|
||||
# only to 403 on the callback.
|
||||
@app.get(
|
||||
"/auth/google/authorize-redirect",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(registration_allowed)],
|
||||
)
|
||||
async def google_authorize_redirect(
|
||||
request: Request,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -22,10 +22,12 @@ def init_worker(**kwargs):
|
|||
initialize_image_gen_router,
|
||||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
|
|
|||
|
|
@ -47,11 +47,37 @@ def load_global_llm_configs():
|
|||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
for cfg in configs:
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
cfg.setdefault("anonymous_enabled", False)
|
||||
cfg.setdefault("seo_enabled", False)
|
||||
# Capability flag: explicit YAML override always wins. When the
|
||||
# operator has not annotated the model, defer to LiteLLM's
|
||||
# authoritative model map (`supports_vision`) which already
|
||||
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
|
||||
# vision-capable. Unknown / unmapped models default-allow so
|
||||
# we don't lock the user out of a freshly added third-party
|
||||
# entry; the streaming-task safety net (driven by
|
||||
# `is_known_text_only_chat_model`) is the only place a False
|
||||
# actually blocks a request.
|
||||
if "supports_image_input" not in cfg:
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
|
||||
slug = cfg["seo_slug"]
|
||||
|
|
@ -138,7 +164,11 @@ def load_global_image_gen_configs():
|
|||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_image_generation_configs", [])
|
||||
configs = data.get("global_image_generation_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||
return []
|
||||
|
|
@ -153,7 +183,11 @@ def load_global_vision_llm_configs():
|
|||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("global_vision_llm_configs", [])
|
||||
configs = data.get("global_vision_llm_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
|
@ -254,6 +288,15 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||
)
|
||||
|
||||
# Image generation + vision LLM emission are opt-in (issue L).
|
||||
# OpenRouter's catalogue contains hundreds of image / vision
|
||||
# capable models; auto-injecting all of them into every
|
||||
# deployment would explode the model selector and surprise
|
||||
# operators upgrading from prior versions. Default to False so
|
||||
# admins must explicitly turn them on.
|
||||
settings.setdefault("image_generation_enabled", False)
|
||||
settings.setdefault("vision_enabled", False)
|
||||
|
||||
return settings
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||
|
|
@ -296,10 +339,60 @@ def initialize_openrouter_integration():
|
|||
)
|
||||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
image_configs = service.get_image_generation_configs()
|
||||
if image_configs:
|
||||
config.GLOBAL_IMAGE_GEN_CONFIGS.extend(image_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(image_configs)} "
|
||||
f"image-generation models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def initialize_pricing_registration():
    """
    Register per-token pricing with LiteLLM for the deployments listed in
    ``config.GLOBAL_LLM_CONFIGS`` (OpenRouter dynamic models taken from
    the OpenRouter catalogue, plus any pricing the operator declared in
    YAML).

    Ordering: call this AFTER ``initialize_openrouter_integration()`` so
    the OpenRouter catalogue is already populated, and BEFORE the first
    LLM call so ``response_cost`` is available in
    ``TokenTrackingCallback``.

    This is best-effort: any exception (including a failed import of the
    registration service) is reported with a printed warning and then
    suppressed, so startup is never blocked by a missing pricing entry.
    The worst case is that the affected model debits 0.
    """
    try:
        # Imported lazily, inside the guard, so an import failure is
        # treated the same as a registration failure.
        from app.services.pricing_registration import (
            register_pricing_from_global_configs as register_pricing,
        )

        register_pricing()
    except Exception as exc:
        warning = f"Warning: Failed to register LiteLLM pricing: {exc}"
        print(warning)
|
||||
|
||||
|
||||
def initialize_llm_router():
|
||||
"""
|
||||
Initialize the LLM Router service for Auto mode.
|
||||
|
|
@ -444,14 +537,54 @@ class Config:
|
|||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Premium token quota settings
|
||||
PREMIUM_TOKEN_LIMIT = int(os.getenv("PREMIUM_TOKEN_LIMIT", "3000000"))
|
||||
# Premium credit (micro-USD) quota settings.
|
||||
#
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
|
||||
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
|
||||
# still honoured for one release as fall-back values — the prior
|
||||
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
|
||||
# to micros, so operators upgrading without changing their .env still
|
||||
# get correct behaviour. A startup deprecation warning fires below if
|
||||
# they're set.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT = int(
|
||||
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
|
||||
)
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_TOKENS_PER_UNIT = int(os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000"))
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT = int(
|
||||
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
|
||||
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
|
||||
)
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
|
||||
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
|
||||
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
|
||||
# config's ``quota_reserve_tokens`` and clamps the result to this value
|
||||
# so a misconfigured "$1000/M" model can't lock the user's whole balance
|
||||
# on one call. Default $1.00 covers realistic worst-cases (Opus + 4K
|
||||
# reserve_tokens ≈ $0.36) with headroom.
|
||||
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
|
||||
|
||||
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT"
|
||||
):
|
||||
print(
|
||||
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
|
||||
"current Stripe price). The old key will be removed in a "
|
||||
"future release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_TOKENS_PER_UNIT is deprecated; rename to "
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
|
||||
"The old key will be removed in a future release."
|
||||
)
|
||||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
ANON_TOKEN_LIMIT = int(os.getenv("ANON_TOKEN_LIMIT", "500000"))
|
||||
|
|
@ -464,6 +597,35 @@ class Config:
|
|||
# Default quota reserve tokens when not specified per-model
|
||||
QUOTA_MAX_RESERVE_PER_CALL = int(os.getenv("QUOTA_MAX_RESERVE_PER_CALL", "8000"))
|
||||
|
||||
# Per-image reservation (in micro-USD) used by ``billable_call`` for the
|
||||
# ``POST /image-generations`` endpoint when the global config does not
|
||||
# override it. $0.05 covers realistic worst-cases for current OpenAI /
|
||||
# OpenRouter image-gen pricing. Bypassed entirely for free configs.
|
||||
QUOTA_DEFAULT_IMAGE_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||
)
|
||||
|
||||
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||
# premium-model run. Tune via env.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_PODCAST_RESERVE_MICROS", "200000")
|
||||
)
|
||||
|
||||
# Per-video-presentation reservation (in micro-USD). Fan-out of N
|
||||
# slide-scene generations (up to ``VIDEO_PRESENTATION_MAX_SLIDES=30``)
|
||||
# plus refine retries; can produce many premium completions. $1.00
|
||||
# covers worst-case. Tune via env.
|
||||
#
|
||||
# NOTE: this equals the existing ``QUOTA_MAX_RESERVE_MICROS`` default of
|
||||
# 1_000_000. The override path in ``billable_call`` bypasses the
|
||||
# per-call clamp in ``estimate_call_reserve_micros``, so this is the
|
||||
# *actual* hold — raising it via env is fine but means a single video
|
||||
# task can lock $1+ of credit.
|
||||
QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS = int(
|
||||
os.getenv("QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS", "1000000")
|
||||
)
|
||||
|
||||
# Abuse prevention: concurrent stream cap and CAPTCHA
|
||||
ANON_MAX_CONCURRENT_STREAMS = int(os.getenv("ANON_MAX_CONCURRENT_STREAMS", "2"))
|
||||
ANON_CAPTCHA_REQUEST_THRESHOLD = int(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,24 @@
|
|||
# Structure matches NewLLMConfig:
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
#
|
||||
# COST-BASED PREMIUM CREDITS:
|
||||
# Each premium config bills the user's USD-credit balance based on the
|
||||
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||
# per-token costs inline so they bill correctly:
|
||||
#
|
||||
# litellm_params:
|
||||
# base_model: "my-custom-azure-deploy"
|
||||
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||
# input_cost_per_token: 0.000003
|
||||
# output_cost_per_token: 0.000015
|
||||
#
|
||||
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||
# API — no inline declaration needed. Models without resolvable pricing
|
||||
# debit $0 from the user's balance and log a WARNING.
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
|
|
@ -292,6 +310,17 @@ openrouter_integration:
|
|||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||
# contains hundreds of image- and vision-capable models; turning these on
|
||||
# injects them into the global Image-Generation / Vision-LLM model
|
||||
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||
# When a user picks a premium image/vision model the call debits the
|
||||
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||
# avoids surprise quota burn on existing deployments. Default: false.
|
||||
image_generation_enabled: false
|
||||
vision_enabled: false
|
||||
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
system_instructions: ""
|
||||
|
|
|
|||
|
|
@ -731,6 +731,7 @@ class TokenUsage(BaseModel, TimestampMixin):
|
|||
prompt_tokens = Column(Integer, nullable=False, default=0)
|
||||
completion_tokens = Column(Integer, nullable=False, default=0)
|
||||
total_tokens = Column(Integer, nullable=False, default=0)
|
||||
cost_micros = Column(BigInteger, nullable=False, default=0, server_default="0")
|
||||
model_breakdown = Column(JSONB, nullable=True)
|
||||
call_details = Column(JSONB, nullable=True)
|
||||
|
||||
|
|
@ -1793,7 +1794,15 @@ class PagePurchase(Base, TimestampMixin):
|
|||
|
||||
|
||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium token credits."""
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
|
||||
|
||||
Note: the table name is preserved (``premium_token_purchases``) for
|
||||
operational continuity even though the unit is now USD micro-credits
|
||||
instead of raw tokens. The ``credit_micros_granted`` column replaced
|
||||
the legacy ``tokens_granted`` in migration 140; the stored values
|
||||
were not transformed because the prior $1 = 1M tokens Stripe price
|
||||
makes the unit conversion 1:1 numerically.
|
||||
"""
|
||||
|
||||
__tablename__ = "premium_token_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
|
@ -1810,7 +1819,7 @@ class PremiumTokenPurchase(Base, TimestampMixin):
|
|||
)
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
||||
quantity = Column(Integer, nullable=False)
|
||||
tokens_granted = Column(BigInteger, nullable=False)
|
||||
credit_micros_granted = Column(BigInteger, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
status = Column(
|
||||
|
|
@ -2109,16 +2118,16 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
|
|
@ -2241,16 +2250,16 @@ else:
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_tokens_limit = Column(
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_TOKEN_LIMIT,
|
||||
server_default=str(config.PREMIUM_TOKEN_LIMIT),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
premium_tokens_used = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_tokens_reserved = Column(
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,12 +68,25 @@ class EtlPipelineService:
|
|||
etl_service="VISION_LLM",
|
||||
content_type="image",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"Vision LLM failed for %s, falling back to document parser",
|
||||
request.filename,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
# Special-case quota exhaustion so we log a clearer message
|
||||
# — the vision LLM didn't "fail", the user just ran out of
|
||||
# premium credit. Falling through to the document parser
|
||||
# is a graceful degradation: OCR/Unstructured still
|
||||
# extracts text from the image without burning credit.
|
||||
from app.services.billable_calls import QuotaInsufficientError
|
||||
|
||||
if isinstance(exc, QuotaInsufficientError):
|
||||
logging.info(
|
||||
"Vision LLM quota exhausted for %s; falling back to document parser",
|
||||
request.filename,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"Vision LLM failed for %s, falling back to document parser",
|
||||
request.filename,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
"No vision LLM provided, falling back to document parser for %s",
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
|
||||
|
|
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
|
|||
|
||||
enable_otel: bool
|
||||
|
||||
enable_desktop_local_filesystem: bool
|
||||
|
||||
@classmethod
|
||||
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
|
||||
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
||||
return cls(**asdict(flags))
|
||||
return cls(
|
||||
**asdict(flags),
|
||||
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
|
|
|
|||
|
|
@ -649,13 +649,9 @@ async def list_composio_drive_folders(
|
|||
"""
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
||||
connector, with Composio-sourced credentials. This means auth errors
|
||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
||||
Uses Composio's Google Drive tool execution path so managed OAuth tokens
|
||||
do not need to be exposed through connected account state.
|
||||
"""
|
||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
|
|
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
|
|||
detail="Composio connected account not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
credentials = build_composio_credentials(composio_connected_account_id)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
||||
service = ComposioService()
|
||||
entity_id = f"surfsense_{user.id}"
|
||||
items = []
|
||||
page_token = None
|
||||
error = None
|
||||
|
||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
while True:
|
||||
page_items, next_token, page_error = await service.get_drive_files(
|
||||
connected_account_id=composio_connected_account_id,
|
||||
entity_id=entity_id,
|
||||
folder_id=parent_id,
|
||||
page_token=page_token,
|
||||
page_size=100,
|
||||
)
|
||||
if page_error:
|
||||
error = page_error
|
||||
break
|
||||
|
||||
items.extend(page_items)
|
||||
if not next_token:
|
||||
break
|
||||
page_token = next_token
|
||||
|
||||
for item in items:
|
||||
item["isFolder"] = (
|
||||
item.get("mimeType") == "application/vnd.google-apps.folder"
|
||||
)
|
||||
|
||||
items.sort(
|
||||
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
|
||||
)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
|
|
|
|||
|
|
@ -36,11 +36,17 @@ from app.schemas import (
|
|||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
billable_call,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
    """Return the LiteLLM provider prefix that goes in front of a model name.

    A truthy ``custom_provider`` is returned verbatim. Otherwise the
    upper-cased ``provider`` is looked up in ``_PROVIDER_MAP``; when the
    map has no entry, the lower-cased ``provider`` is used as the prefix.
    """
    # ``or`` short-circuits exactly like the guard clause it replaces: the
    # map lookup only happens when no custom provider was supplied.
    return custom_provider or _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Work out ``(billing_tier, base_model, reserve_micros)`` for a request.

    This follows the same config lookup tree as
    ``_execute_image_generation`` but pulls out only the billing-relevant
    fields. It is meant to run *before* ``billable_call`` so the
    reservation is sized for the config that will actually execute, and
    so no ``ImageGeneration`` row is opened for a request that is about
    to be rejected with a 402.

    Resolution rules:

    * ``config_id`` wins when given; otherwise the search space's
      ``image_generation_config_id`` is used, falling back to
      ``IMAGE_GEN_AUTO_MODE_ID`` when that is unset/falsy.
    * Auto mode is currently treated as free, because the underlying
      router can dispatch to either premium or free YAML configs and the
      resolved deployment is not surfaced here yet. Bringing Auto under
      premium billing would require threading the chosen deployment back
      from ``ImageGenRouterService``.
    * Negative IDs are global (YAML) configs: tier, model string and the
      optional ``quota_reserve_micros`` override are read from the config.
    * Remaining (non-negative) IDs are user-owned BYOK configs, which are
      always free.

    ``session`` is accepted for signature parity but is not used in this
    function's body.
    """
    effective_id = (
        config_id
        if config_id is not None
        else (search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID)
    )

    if is_image_gen_auto_mode(effective_id):
        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)

    if effective_id >= 0:
        # Non-negative, non-auto ID = user-owned BYOK image-gen config — always free.
        return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)

    # Negative ID = global config; a missing config degrades to an empty
    # dict so the defaults below apply.
    global_cfg = _get_global_image_gen_config(effective_id) or {}
    tier = str(global_cfg.get("billing_tier", "free")).lower()
    model_string = _build_model_string(
        global_cfg.get("provider", ""),
        global_cfg.get("model_name", ""),
        global_cfg.get("custom_provider"),
    )
    reserve = global_cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
    return (tier, model_string, int(reserve))
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
|
|
@ -138,12 +192,18 @@ async def _execute_image_generation(
|
|||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -165,12 +225,18 @@ async def _execute_image_generation(
|
|||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +531,26 @@ async def create_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
The flow is:
|
||||
|
||||
1. Permission check + load the search space (cheap, no provider call).
|
||||
2. Resolve which config will run so we know its billing tier and the
|
||||
worst-case reservation size *before* opening any DB rows.
|
||||
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||
no inserted error rows for "you ran out of credit").
|
||||
4. On success, the actual ``response_cost`` flows through the
|
||||
LiteLLM callback into the accumulator, and ``billable_call``
|
||||
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||
provider errors and stores them on ``error_message`` (HTTP 200
|
||||
with ``error_message`` set is preserved for failed-but-not-quota
|
||||
scenarios — clients already know how to surface those).
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -471,33 +567,70 @@ async def create_image_generation(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||
# ``error_message`` because that would (1) imply a successful row
|
||||
# exists when none does, and (2) return HTTP 200 to a client
|
||||
# whose request was actively *denied* (issue K).
|
||||
async with billable_call(
|
||||
user_id=search_space.user_id,
|
||||
search_space_id=data.search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=reserve_micros,
|
||||
usage_type="image_generation",
|
||||
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||
):
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except QuotaInsufficientError as exc:
|
||||
# The user's premium credit pool is empty. No DB row is created
|
||||
# because ``billable_call`` denies before yielding (issue K).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
) from exc
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -1366,7 +1366,11 @@ async def append_message(
|
|||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
|
|
@ -1377,6 +1381,7 @@ async def append_message(
|
|||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.schemas import (
|
|||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -36,6 +37,39 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
    """Build the API read model for a BYOK chat config, adding ``supports_image_input``.

    ``supports_image_input`` has no DB column; it is computed here via
    ``derive_supports_image_input`` from the config's provider, model
    name, optional ``litellm_params["base_model"]`` and custom provider.
    Routing every list / detail / create / update response through this
    helper keeps the response shape uniform without each call site having
    to set the field itself.
    """
    provider = config.provider
    # The provider may be an enum-like object (use ``.value``) or something
    # that only stringifies.
    provider_name = provider.value if hasattr(provider, "value") else str(provider)

    params = config.litellm_params or {}
    base_model = None
    if isinstance(params, dict):
        base_model = params.get("base_model")

    image_capable = derive_supports_image_input(
        provider=provider_name,
        model_name=config.model_name,
        base_model=base_model,
        custom_provider=config.custom_provider,
    )

    # Validate from the ORM row first, then layer the derived flag on a
    # copy so the validated model itself is not mutated.
    return NewLLMConfigRead.model_validate(config).model_copy(
        update={"supports_image_input": image_capable}
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
|
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or the
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
|
|
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
|
|
@ -171,7 +236,7 @@ async def create_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return db_config
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -213,7 +278,7 @@ async def list_new_llm_configs(
|
|||
.limit(limit)
|
||||
)
|
||||
|
||||
return result.scalars().all()
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -268,7 +333,7 @@ async def get_new_llm_config(
|
|||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -360,7 +425,7 @@ async def update_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -591,6 +591,7 @@ async def _get_image_gen_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -607,6 +608,7 @@ async def _get_image_gen_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -649,6 +651,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -665,6 +668,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
|||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
||||
# Read the new metadata key first, fall back to the legacy one so
|
||||
# in-flight checkout sessions created before the cost-credits
|
||||
# release still fulfil correctly (the unit is numerically the
|
||||
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||
credit_micros_per_unit = int(
|
||||
metadata.get("credit_micros_per_unit")
|
||||
or metadata.get("tokens_per_unit", "0")
|
||||
)
|
||||
|
||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
||||
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||
checkout_session_id,
|
||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
tokens_granted=quantity * tokens_per_unit,
|
||||
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
|||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
user.premium_tokens_limit = (
|
||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
||||
+ purchase.tokens_granted
|
||||
# Top up the user's credit balance by the granted micro-USD amount.
|
||||
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||
# used value above the limit (e.g. underbilling rounding) so adding
|
||||
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||
# size rather than disappearing into past overuse.
|
||||
user.premium_credit_micros_limit = (
|
||||
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||
+ purchase.credit_micros_granted
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
||||
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||
|
||||
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||
at the actual provider cost reported by LiteLLM at finalize time,
|
||||
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||
"""
|
||||
_ensure_token_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_token_price_id()
|
||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
||||
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
|||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
||||
"purchase_type": "premium_tokens",
|
||||
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||
"purchase_type": "premium_credit",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
tokens_granted=tokens_granted,
|
||||
credit_micros_granted=credit_micros_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
|||
async def get_token_status(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return token-buying availability and current premium quota for frontend."""
|
||||
used = user.premium_tokens_used
|
||||
limit = user.premium_tokens_limit
|
||||
"""Return token-buying availability and current premium credit quota for frontend.
|
||||
|
||||
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||
when displaying. The route name is preserved for back-compat with
|
||||
pinned client deployments.
|
||||
"""
|
||||
used = user.premium_credit_micros_used
|
||||
limit = user.premium_credit_micros_limit
|
||||
return TokenStripeStatusResponse(
|
||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||
premium_tokens_used=used,
|
||||
premium_tokens_limit=limit,
|
||||
premium_tokens_remaining=max(0, limit - used),
|
||||
premium_credit_micros_used=used,
|
||||
premium_credit_micros_limit=limit,
|
||||
premium_credit_micros_remaining=max(0, limit - used),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -215,6 +215,12 @@ class GlobalImageGenConfigRead(BaseModel):
|
|||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
|
|
@ -231,3 +237,24 @@ class GlobalImageGenConfigRead(BaseModel):
|
|||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_micros: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the reservation amount (in micro-USD) used when "
|
||||
"this image generation is premium. Falls back to "
|
||||
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class TokenUsageSummary(BaseModel):
|
|||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
cost_micros: int = 0
|
||||
model_breakdown: dict | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
|
|||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (no DB column). Default
|
||||
# True matches the conservative-allow stance — a BYOK row that the
|
||||
# route forgot to augment is not pre-judged. The streaming-task
|
||||
# safety net is the only place a False actually blocks a request.
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map "
|
||||
"(``litellm.supports_vision``) — there is no DB column. "
|
||||
"Default True is the conservative-allow stance for unknown / "
|
||||
"unmapped models."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
|
|||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (see NewLLMConfigRead).
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map. "
|
||||
"Default True is the conservative-allow stance."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the model accepts image inputs (multimodal vision). "
|
||||
"Derived server-side: OpenRouter dynamic configs use "
|
||||
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
|
||||
"authoritative model map (``litellm.supports_vision``). The "
|
||||
"new-chat selector hints with a 'No image' badge when this is "
|
||||
"False and there are pending image attachments. The streaming "
|
||||
"task fails fast only when LiteLLM *explicitly* marks a model "
|
||||
"as text-only — unknown / unmapped models default-allow."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -70,13 +70,17 @@ class CreateTokenCheckoutSessionResponse(BaseModel):
|
|||
|
||||
|
||||
class TokenPurchaseRead(BaseModel):
|
||||
"""Serialized premium token purchase record."""
|
||||
"""Serialized premium credit purchase record.
|
||||
|
||||
``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The
|
||||
schema name kept ``Token`` for API back-compat with pinned clients.
|
||||
"""
|
||||
|
||||
id: uuid.UUID
|
||||
stripe_checkout_session_id: str
|
||||
stripe_payment_intent_id: str | None = None
|
||||
quantity: int
|
||||
tokens_granted: int
|
||||
credit_micros_granted: int
|
||||
amount_total: int | None = None
|
||||
currency: str | None = None
|
||||
status: str
|
||||
|
|
@ -87,15 +91,19 @@ class TokenPurchaseRead(BaseModel):
|
|||
|
||||
|
||||
class TokenPurchaseHistoryResponse(BaseModel):
|
||||
"""Response containing the user's premium token purchases."""
|
||||
"""Response containing the user's premium credit purchases."""
|
||||
|
||||
purchases: list[TokenPurchaseRead]
|
||||
|
||||
|
||||
class TokenStripeStatusResponse(BaseModel):
|
||||
"""Response describing token-buying availability and current quota."""
|
||||
"""Response describing premium-credit-buying availability and balance.
|
||||
|
||||
All ``premium_credit_micros_*`` fields are in micro-USD; the FE
|
||||
divides by 1_000_000 to display USD.
|
||||
"""
|
||||
|
||||
token_buying_enabled: bool
|
||||
premium_tokens_used: int = 0
|
||||
premium_tokens_limit: int = 0
|
||||
premium_tokens_remaining: int = 0
|
||||
premium_credit_micros_used: int = 0
|
||||
premium_credit_micros_limit: int = 0
|
||||
premium_credit_micros_remaining: int = 0
|
||||
|
|
|
|||
|
|
@ -62,6 +62,15 @@ class VisionLLMConfigPublic(BaseModel):
|
|||
|
||||
|
||||
class GlobalVisionLLMConfigRead(BaseModel):
|
||||
"""Schema for reading global vision LLM configs from YAML.
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(...)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
|
@ -73,3 +82,35 @@ class GlobalVisionLLMConfigRead(BaseModel):
|
|||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_tokens: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the per-call reservation in *tokens* — "
|
||||
"converted to micro-USD via the model's input/output prices at "
|
||||
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
|
||||
),
|
||||
)
|
||||
input_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional input price in USD/token. Used by pricing_registration to "
|
||||
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
|
||||
),
|
||||
)
|
||||
output_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description="Optional output price in USD/token. Pair with input_cost_per_token.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
|
|||
_healthy_until.pop(int(config_id), None)
|
||||
|
||||
|
||||
def _global_candidates() -> list[dict]:
|
||||
def _cfg_supports_image_input(cfg: dict) -> bool:
    """Return True if the global cfg can accept image inputs.

    Prefers the explicit ``supports_image_input`` flag (set by the YAML
    loader / OpenRouter integration). When the flag is absent *or present
    but null*, falls back to ``derive_supports_image_input`` so an entry
    whose flag was stripped or left empty doesn't get wrongly excluded.
    Per the surrounding design notes the streaming-task safety net is the
    actual block, not this filter.

    Args:
        cfg: One global LLM config dict.

    Returns:
        The truthiness of the explicit flag when it carries a value,
        otherwise whatever ``derive_supports_image_input`` reports for the
        config's provider / model identifiers.
    """
    explicit_flag = cfg.get("supports_image_input")
    # ``None`` means "no information", not "text-only": coercing it with
    # ``bool()`` would silently drop the config from image turns.
    if explicit_flag is not None:
        return bool(explicit_flag)

    # Lazy import: provider_capabilities -> llm_config -> services chain;
    # importing at module load would create an init-order cycle through
    # ``app.config``.
    from app.services.provider_capabilities import derive_supports_image_input

    cfg_litellm_params = cfg.get("litellm_params") or {}
    # ``litellm_params`` comes from config data, so guard against a
    # non-mapping value before calling ``.get`` on it.
    base_model = (
        cfg_litellm_params.get("base_model")
        if isinstance(cfg_litellm_params, dict)
        else None
    )
    return derive_supports_image_input(
        provider=cfg.get("provider"),
        model_name=cfg.get("model_name"),
        base_model=base_model,
        custom_provider=cfg.get("custom_provider"),
    )
|
||||
|
||||
|
||||
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
||||
"""Return Auto-eligible global cfgs.
|
||||
|
||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||
can't be picked as the thread's pin. Also excludes configs currently
|
||||
in runtime cooldown (e.g. temporary 429 bursts).
|
||||
|
||||
When ``requires_image_input`` is True (image turn), additionally
|
||||
filters out configs whose ``supports_image_input`` resolves to False
|
||||
so a text-only deployment can't be pinned for an image request.
|
||||
"""
|
||||
candidates = [
|
||||
cfg
|
||||
|
|
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
|
|||
if _is_usable_global_config(cfg)
|
||||
and not cfg.get("health_gated")
|
||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||
and (not requires_image_input or _cfg_supports_image_input(cfg))
|
||||
]
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
|
@ -185,6 +220,15 @@ def _tier_of(cfg: dict) -> str:
|
|||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
    """Tell whether ``cfg`` is the operator-preferred premium Auto model.

    A config qualifies only when all three hold: its billing tier is
    ``"premium"``, its provider is ``AZURE_OPENAI`` (compared
    case-insensitively), and its model name is ``gpt-5.4`` (compared
    case-insensitively).
    """
    if _tier_of(cfg) != "premium":
        return False
    provider_name = str(cfg.get("provider", "")).upper()
    if provider_name != "AZURE_OPENAI":
        return False
    model_name = str(cfg.get("model_name", "")).lower()
    return model_name == "gpt-5.4"
|
||||
|
||||
|
||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||
"""Pick a config with quality-first ranking + deterministic spread.
|
||||
|
||||
|
|
@ -237,11 +281,20 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
exclude_config_ids: set[int] | None = None,
|
||||
requires_image_input: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
|
||||
For non-auto selections, this function clears any existing pin and returns
|
||||
the selected id as-is.
|
||||
|
||||
When ``requires_image_input`` is True (the current turn carries an
|
||||
``image_url`` block), the candidate pool is filtered to vision-capable
|
||||
cfgs and any existing pin that can't accept image input is treated as
|
||||
invalid (force re-pin). If no vision-capable cfg is available the
|
||||
function raises ``ValueError`` so the streaming task surfaces the same
|
||||
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
|
||||
silently routing the image to a text-only deployment.
|
||||
"""
|
||||
thread = (
|
||||
(
|
||||
|
|
@ -274,14 +327,24 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
||||
c
|
||||
for c in _global_candidates(requires_image_input=requires_image_input)
|
||||
if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
if not candidates:
|
||||
if requires_image_input:
|
||||
# Distinguish the "no vision-capable cfg" case from generic
|
||||
# "no usable cfg" so the streaming task can map this to the
|
||||
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||
raise ValueError(
|
||||
"No vision-capable global LLM configs are available for Auto mode"
|
||||
)
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
# tier switch), unless the caller explicitly requests a forced repin to free.
|
||||
# tier switch), unless the caller explicitly requests a forced repin to free
|
||||
# *or* the turn requires image input but the pin can't handle it.
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
not force_repin_free
|
||||
|
|
@ -311,6 +374,29 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
from_existing_pin=True,
|
||||
)
|
||||
if pinned_id is not None:
|
||||
# If the pin is *only* invalid because it can't handle the image
|
||||
# turn (it's still a healthy, usable config in the broader pool),
|
||||
# log that explicitly so operators can correlate the re-pin with
|
||||
# the user's image attachment instead of suspecting a cooldown.
|
||||
if requires_image_input:
|
||||
try:
|
||||
pinned_global = next(
|
||||
c
|
||||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if int(c.get("id", 0)) == int(pinned_id)
|
||||
)
|
||||
except StopIteration:
|
||||
pinned_global = None
|
||||
if pinned_global is not None and not _cfg_supports_image_input(
|
||||
pinned_global
|
||||
):
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||
thread_id,
|
||||
|
|
@ -322,11 +408,19 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||
)
|
||||
if premium_eligible:
|
||||
eligible = candidates
|
||||
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
|
||||
preferred_premium = [
|
||||
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
|
||||
]
|
||||
eligible = preferred_premium or premium_candidates
|
||||
else:
|
||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
if not eligible:
|
||||
if requires_image_input:
|
||||
raise ValueError(
|
||||
"Auto mode could not find a vision-capable LLM config for this user and quota state"
|
||||
)
|
||||
raise ValueError(
|
||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||
)
|
||||
|
|
|
|||
566
surfsense_backend/app/services/billable_calls.py
Normal file
566
surfsense_backend/app/services/billable_calls.py
Normal file
|
|
@ -0,0 +1,566 @@
|
|||
"""
|
||||
Per-call billable wrapper for image generation, vision LLM extraction, and
|
||||
any other short-lived premium operation that must charge against the user's
|
||||
shared premium credit pool.
|
||||
|
||||
The ``billable_call`` async context manager encapsulates the standard
|
||||
"reserve → execute → finalize / release → record audit row" lifecycle in a
|
||||
single primitive so callers (the image-generation REST route and the
|
||||
vision-LLM wrapper used during indexing) don't have to re-implement it.
|
||||
|
||||
KEY DESIGN POINTS (issue A, B):
|
||||
|
||||
1. **Session isolation.** ``billable_call`` takes no caller transaction.
|
||||
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
|
||||
inside their own session context. Route callers use
|
||||
``shielded_async_session()`` by default; Celery callers can provide a
|
||||
worker-loop-safe session factory. This guarantees that quota
|
||||
commit/rollback can never accidentally flush or roll back rows the caller
|
||||
has staged in its main session (e.g. a freshly-created
|
||||
``ImageGeneration`` row).
|
||||
|
||||
2. **ContextVar safety.** The accumulator is scoped via
|
||||
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||
nested ``billable_call`` inside an outer chat turn cannot corrupt the
|
||||
chat turn's accumulator.
|
||||
|
||||
3. **Free configs are still audited.** Free calls bypass the reserve /
|
||||
finalize dance entirely but still record a ``TokenUsage`` audit row with
|
||||
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
|
||||
pipeline complete for analytics even when nothing is debited.
|
||||
|
||||
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
|
||||
responsible for translating that into HTTP 402. We *do not* catch the
|
||||
denial inside ``billable_call`` — letting it propagate also prevents
|
||||
the image-generation route from creating an ``ImageGeneration`` row
|
||||
for a request that never actually ran.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import shielded_async_session
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
from app.services.token_tracking_service import (
|
||||
TurnTokenAccumulator,
|
||||
record_token_usage,
|
||||
scoped_turn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUDIT_TIMEOUT_SECONDS = 10.0
|
||||
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
|
||||
{"video_presentation_generation", "podcast_generation"}
|
||||
)
|
||||
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
|
||||
|
||||
|
||||
class QuotaInsufficientError(Exception):
    """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
    call because the user has exhausted their premium credit pool.

    The route handler should catch this and return HTTP 402 Payment
    Required (or the equivalent for the surface area). Outside of the HTTP
    layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
    callers may catch this and degrade gracefully — e.g. fall back to OCR
    when vision is unavailable.

    Attributes:
        usage_type: The billable operation that was denied.
        used_micros: Credit already consumed, in micro-USD.
        limit_micros: Credit limit, in micro-USD.
        remaining_micros: Credit still available, in micro-USD.

    The constructor is keyword-only, so the default exception reduction
    (which replays ``cls(*self.args)`` with just the message) cannot rebuild
    an instance. ``__reduce__`` is overridden so ``copy`` and ``pickle``
    round-trips — e.g. when an exception crosses a worker-process boundary —
    keep working.
    """

    def __init__(
        self,
        *,
        usage_type: str,
        used_micros: int,
        limit_micros: int,
        remaining_micros: int,
    ) -> None:
        self.usage_type = usage_type
        self.used_micros = used_micros
        self.limit_micros = limit_micros
        self.remaining_micros = remaining_micros
        super().__init__(
            f"Premium credit exhausted for {usage_type}: "
            f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
        )

    @classmethod
    def _rebuild(
        cls,
        usage_type: str,
        used_micros: int,
        limit_micros: int,
        remaining_micros: int,
    ) -> "QuotaInsufficientError":
        """Positional-argument constructor used only by ``__reduce__``."""
        return cls(
            usage_type=usage_type,
            used_micros=used_micros,
            limit_micros=limit_micros,
            remaining_micros=remaining_micros,
        )

    def __reduce__(self):
        # Replay the structured fields rather than ``self.args`` (the
        # formatted message), which the keyword-only ``__init__`` rejects.
        return (
            type(self)._rebuild,
            (
                self.usage_type,
                self.used_micros,
                self.limit_micros,
                self.remaining_micros,
            ),
        )
|
||||
|
||||
|
||||
class BillingSettlementError(Exception):
    """Raised when a premium call completed but credit settlement failed.

    Attributes:
        usage_type: The billable operation whose settlement failed.
        user_id: Owner of the credit pool that could not be settled.
        cause: The underlying exception raised during settlement.

    The constructor is keyword-only, so the default exception reduction
    (which replays ``cls(*self.args)`` with just the message) cannot rebuild
    an instance. ``__reduce__`` is overridden so ``copy`` and ``pickle``
    round-trips keep working; pickling still requires ``cause`` itself to
    be picklable.
    """

    def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
        self.usage_type = usage_type
        self.user_id = user_id
        # Kept so the instance can be rebuilt by ``__reduce__`` and so
        # handlers can inspect the failure without relying on ``__cause__``.
        self.cause = cause
        super().__init__(
            f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
        )

    @classmethod
    def _rebuild(
        cls, usage_type: str, user_id: UUID, cause: Exception
    ) -> "BillingSettlementError":
        """Positional-argument constructor used only by ``__reduce__``."""
        return cls(usage_type=usage_type, user_id=user_id, cause=cause)

    def __reduce__(self):
        # Replay the structured fields rather than ``self.args`` (the
        # formatted message), which the keyword-only ``__init__`` rejects.
        return (type(self)._rebuild, (self.usage_type, self.user_id, self.cause))
|
||||
|
||||
|
||||
async def _rollback_safely(session: AsyncSession) -> None:
    """Attempt ``session.rollback()`` without ever raising.

    Objects that expose no ``rollback`` attribute (or expose ``None``) are
    left untouched, and any ``Exception`` raised by the rollback itself is
    swallowed — this helper runs on error paths where the original failure
    is the one that must propagate.
    """
    do_rollback = getattr(session, "rollback", None)
    if do_rollback is None:
        return
    with suppress(Exception):
        await do_rollback()
|
||||
|
||||
|
||||
async def _record_audit_best_effort(
    *,
    session_factory: BillableSessionFactory,
    usage_type: str,
    search_space_id: int,
    user_id: UUID,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    cost_micros: int,
    model_breakdown: dict[str, Any],
    call_details: dict[str, Any] | None,
    thread_id: int | None,
    message_id: int | None,
    audit_label: str,
    timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> None:
    """Write one TokenUsage audit row, never letting a failure reach the caller.

    Settlement of premium credit is mandatory; the TokenUsage row is only an
    audit trail. A hung or failing insert/commit must therefore not stop
    user-facing artifacts (videos, podcasts) from becoming READY once
    settlement is done, so timeouts and ordinary exceptions are logged and
    swallowed here.

    The row is written in a fresh session obtained from ``session_factory``
    and bounded by ``timeout_seconds``. For usage types listed in
    ``BACKGROUND_ARTIFACT_USAGE_TYPES`` the row is recorded without a
    ``thread_id``.
    """
    if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES:
        audit_thread_id = None
    else:
        audit_thread_id = thread_id

    # Positional log arguments shared by every message below.
    who = (audit_label, usage_type, user_id, audit_thread_id)
    volume = (total_tokens, cost_micros)

    async def _write_row() -> None:
        logger.info(
            "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            *who,
            *volume,
        )
        async with session_factory() as audit_session:
            try:
                await record_token_usage(
                    audit_session,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    cost_micros=cost_micros,
                    model_breakdown=model_breakdown,
                    call_details=call_details,
                    thread_id=audit_thread_id,
                    message_id=message_id,
                )
                logger.info(
                    "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
                    *who,
                )
                await audit_session.commit()
                logger.info(
                    "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
                    *who,
                )
            except BaseException:
                # BaseException so a cancellation triggered by the timeout
                # below also rolls the audit session back before unwinding.
                await _rollback_safely(audit_session)
                raise

    try:
        await asyncio.wait_for(_write_row(), timeout=timeout_seconds)
    except TimeoutError:
        logger.warning(
            "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
            "timeout=%.1fs total_tokens=%d cost_micros=%d",
            *who,
            timeout_seconds,
            *volume,
        )
    except Exception:
        logger.exception(
            "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            *who,
            *volume,
        )
|
||||
|
||||
|
||||
@asynccontextmanager
async def billable_call(
    *,
    user_id: UUID,
    search_space_id: int,
    billing_tier: str,
    base_model: str,
    quota_reserve_tokens: int | None = None,
    quota_reserve_micros_override: int | None = None,
    usage_type: str,
    thread_id: int | None = None,
    message_id: int | None = None,
    call_details: dict[str, Any] | None = None,
    billable_session_factory: BillableSessionFactory | None = None,
    audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]:
    """Wrap a single billable LLM/image call.

    Args:
        user_id: Owner of the credit pool to debit. For vision-LLM during
            indexing this is the *search-space owner* (issue M), not the
            triggering user.
        search_space_id: Required — recorded on the ``TokenUsage`` audit row.
        billing_tier: ``"premium"`` (compared case-insensitively, matching
            how the rest of the billing code normalises tiers) debits;
            anything else (``"free"``) skips the reserve/finalize dance but
            still records an audit row with the captured cost.
        base_model: Used by :func:`estimate_call_reserve_micros` to compute
            a worst-case reservation from LiteLLM's pricing table.
        quota_reserve_tokens: Optional per-config override for the chat-style
            reserve estimator (vision LLM uses this).
        quota_reserve_micros_override: Optional flat micro-USD reservation
            (image generation uses this — its cost shape is per-image, not
            per-token). Clamped to at least 1.
        usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
            Recorded on the ``TokenUsage`` row.
        thread_id, message_id: Optional FK columns on ``TokenUsage``.
        call_details: Optional per-call metadata (model name, parameters)
            forwarded to ``record_token_usage``.
        billable_session_factory: Optional async context factory used for
            reserve/finalize/release/audit sessions. Defaults to
            ``shielded_async_session`` for route callers; Celery callers pass
            a worker-loop-safe session factory.
        audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
            Audit failure is best-effort and does not undo successful
            settlement.

    Yields:
        The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
        the underlying LLM/image API while inside the ``async with``; the
        ``TokenTrackingCallback`` populates the accumulator automatically.

    Raises:
        QuotaInsufficientError: when premium and ``premium_reserve`` denies.
        BillingSettlementError: when the wrapped call succeeded but
            ``premium_finalize`` failed (a release is attempted first).
    """
    # Normalise the tier so a config carrying "Premium" / "PREMIUM" cannot
    # silently bypass the debit path.
    is_premium = str(billing_tier).lower() == "premium"
    session_factory = billable_session_factory or shielded_async_session

    async with scoped_turn() as acc:
        # ---------- Free path: just audit -------------------------------
        if not is_premium:
            try:
                yield acc
            finally:
                # Always audit, even on exception, so we capture cost when
                # provider returns successfully but the caller raises later.
                await _record_audit_best_effort(
                    session_factory=session_factory,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=acc.total_prompt_tokens,
                    completion_tokens=acc.total_completion_tokens,
                    total_tokens=acc.grand_total,
                    cost_micros=acc.total_cost_micros,
                    model_breakdown=acc.per_message_summary(),
                    call_details=call_details,
                    thread_id=thread_id,
                    message_id=message_id,
                    audit_label="free",
                    timeout_seconds=audit_timeout_seconds,
                )
            return

        # ---------- Premium path: reserve → execute → finalize ----------
        if quota_reserve_micros_override is not None:
            reserve_micros = max(1, int(quota_reserve_micros_override))
        else:
            reserve_micros = estimate_call_reserve_micros(
                base_model=base_model or "",
                quota_reserve_tokens=quota_reserve_tokens,
            )

        # Shared by the reserve and finalize calls for this invocation.
        request_id = str(uuid4())

        async with session_factory() as quota_session:
            reserve_result = await TokenQuotaService.premium_reserve(
                db_session=quota_session,
                user_id=user_id,
                request_id=request_id,
                reserve_micros=reserve_micros,
            )

        if not reserve_result.allowed:
            logger.info(
                "[billable_call] reserve DENIED user=%s usage_type=%s "
                "reserve=%d used=%d limit=%d remaining=%d",
                user_id,
                usage_type,
                reserve_micros,
                reserve_result.used,
                reserve_result.limit,
                reserve_result.remaining,
            )
            raise QuotaInsufficientError(
                usage_type=usage_type,
                used_micros=reserve_result.used,
                limit_micros=reserve_result.limit,
                remaining_micros=reserve_result.remaining,
            )

        logger.info(
            "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
            "(remaining=%d)",
            user_id,
            usage_type,
            reserve_micros,
            reserve_result.remaining,
        )

        try:
            yield acc
        except BaseException:
            # Release on any failure (including QuotaInsufficientError raised
            # from a downstream call, asyncio cancellation, etc.). We use
            # BaseException so cancellation also releases.
            try:
                async with session_factory() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] premium_release failed for user=%s "
                    "reserve_micros=%d (reservation will be GC'd by quota "
                    "reconciliation if/when implemented)",
                    user_id,
                    reserve_micros,
                )
            raise

        # ---------- Success: finalize + audit ----------------------------
        actual_micros = acc.total_cost_micros
        try:
            logger.info(
                "[billable_call] finalize start user=%s usage_type=%s actual=%d "
                "reserved=%d thread=%s",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                thread_id,
            )
            async with session_factory() as quota_session:
                final_result = await TokenQuotaService.premium_finalize(
                    db_session=quota_session,
                    user_id=user_id,
                    request_id=request_id,
                    actual_micros=actual_micros,
                    reserved_micros=reserve_micros,
                )
            logger.info(
                "[billable_call] finalize user=%s usage_type=%s actual=%d "
                "reserved=%d → used=%d/%d (remaining=%d)",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                final_result.used,
                final_result.limit,
                final_result.remaining,
            )
        except Exception as finalize_exc:
            # Last-ditch: if finalize itself fails, we must at least release
            # so the reservation doesn't leak.
            logger.exception(
                "[billable_call] premium_finalize failed for user=%s; "
                "attempting release",
                user_id,
            )
            try:
                async with session_factory() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] release after finalize failure ALSO failed "
                    "for user=%s",
                    user_id,
                )
            raise BillingSettlementError(
                usage_type=usage_type,
                user_id=user_id,
                cause=finalize_exc,
            ) from finalize_exc

        # Settlement succeeded; the audit row is best-effort from here on.
        await _record_audit_best_effort(
            session_factory=session_factory,
            usage_type=usage_type,
            search_space_id=search_space_id,
            user_id=user_id,
            prompt_tokens=acc.total_prompt_tokens,
            completion_tokens=acc.total_completion_tokens,
            total_tokens=acc.grand_total,
            cost_micros=actual_micros,
            model_breakdown=acc.per_message_summary(),
            call_details=call_details,
            thread_id=thread_id,
            message_id=message_id,
            audit_label="premium",
            timeout_seconds=audit_timeout_seconds,
        )
|
||||
|
||||
|
||||
async def _resolve_agent_billing_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
*,
|
||||
thread_id: int | None = None,
|
||||
) -> tuple[UUID, str, str]:
|
||||
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
|
||||
agent LLM.
|
||||
|
||||
Used by Celery tasks (podcast generation, video presentation) to bill the
|
||||
search-space owner's premium credit pool when the agent LLM is premium.
|
||||
|
||||
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
|
||||
|
||||
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
|
||||
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
|
||||
* ``thread_id`` is set: delegate to
|
||||
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
|
||||
recurse into the resolved id. Reuses chat's existing pin if present
|
||||
so the same model bills for chat + downstream podcast/video. If the
|
||||
user is not premium-eligible, the pin service auto-restricts to free
|
||||
deployments — denial only happens later in
|
||||
``billable_call.premium_reserve`` if the pin really is premium and
|
||||
credit ran out mid-flow.
|
||||
* ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
|
||||
for any future direct-API path; today both Celery tasks always pass
|
||||
``thread_id``.
|
||||
- **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
|
||||
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
|
||||
``base_model = litellm_params.get("base_model") or model_name`` —
|
||||
NOT provider-prefixed, matching chat's cost-map lookup convention.
|
||||
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
|
||||
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
|
||||
``base_model`` from ``litellm_params`` or ``model_name``.
|
||||
|
||||
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
|
||||
``llm_router_service`` are imported lazily inside the function body to
|
||||
avoid hoisting litellm side-effects (``litellm.callbacks =
|
||||
[token_tracker]``, ``litellm.drop_params``, etc.) into
|
||||
``billable_calls.py``'s module load path.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if search_space is None:
|
||||
raise ValueError(f"Search space {search_space_id} not found")
|
||||
|
||||
agent_llm_id = search_space.agent_llm_id
|
||||
if agent_llm_id is None:
|
||||
raise ValueError(
|
||||
f"Search space {search_space_id} has no agent_llm_id configured"
|
||||
)
|
||||
|
||||
owner_user_id: UUID = search_space.user_id
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_ID,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
if agent_llm_id == AUTO_FASTEST_ID:
|
||||
if thread_id is None:
|
||||
return owner_user_id, "free", "auto"
|
||||
try:
|
||||
resolution = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=str(owner_user_id),
|
||||
selected_llm_config_id=AUTO_FASTEST_ID,
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[agent_billing] Auto-mode pin resolution failed for "
|
||||
"search_space=%s thread=%s; falling back to free",
|
||||
search_space_id,
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return owner_user_id, "free", "auto"
|
||||
agent_llm_id = resolution.resolved_llm_config_id
|
||||
|
||||
if agent_llm_id < 0:
|
||||
from app.services.llm_service import get_global_llm_config
|
||||
|
||||
cfg = get_global_llm_config(agent_llm_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
|
||||
return owner_user_id, billing_tier, base_model
|
||||
|
||||
nlc_result = await session.execute(
|
||||
select(NewLLMConfig).where(
|
||||
NewLLMConfig.id == agent_llm_id,
|
||||
NewLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
nlc = nlc_result.scalars().first()
|
||||
base_model = ""
|
||||
if nlc is not None:
|
||||
litellm_params = nlc.litellm_params or {}
|
||||
base_model = litellm_params.get("base_model") or nlc.model_name or ""
|
||||
return owner_user_id, "free", base_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BillingSettlementError",
|
||||
"QuotaInsufficientError",
|
||||
"_resolve_agent_billing_for_search_space",
|
||||
"billable_call",
|
||||
]
|
||||
|
||||
|
||||
# Re-export the config knob so callers don't have to import config just for
|
||||
# the default image reserve.
|
||||
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
|
@ -408,12 +408,37 @@ class ComposioService:
|
|||
files = []
|
||||
next_token = None
|
||||
if isinstance(data, dict):
|
||||
inner_data = data.get("data", data)
|
||||
response_data = (
|
||||
inner_data.get("response_data", {})
|
||||
if isinstance(inner_data, dict)
|
||||
else {}
|
||||
)
|
||||
# Try direct access first, then nested
|
||||
files = data.get("files", []) or data.get("data", {}).get("files", [])
|
||||
files = (
|
||||
data.get("files", [])
|
||||
or (
|
||||
inner_data.get("files", [])
|
||||
if isinstance(inner_data, dict)
|
||||
else []
|
||||
)
|
||||
or response_data.get("files", [])
|
||||
)
|
||||
next_token = (
|
||||
data.get("nextPageToken")
|
||||
or data.get("next_page_token")
|
||||
or data.get("data", {}).get("nextPageToken")
|
||||
or (
|
||||
inner_data.get("nextPageToken")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or (
|
||||
inner_data.get("next_page_token")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or response_data.get("nextPageToken")
|
||||
or response_data.get("next_page_token")
|
||||
)
|
||||
elif isinstance(data, list):
|
||||
files = data
|
||||
|
|
@ -819,24 +844,61 @@ class ComposioService:
|
|||
next_token = None
|
||||
result_size_estimate = None
|
||||
if isinstance(data, dict):
|
||||
inner_data = data.get("data", data)
|
||||
response_data = (
|
||||
inner_data.get("response_data", {})
|
||||
if isinstance(inner_data, dict)
|
||||
else {}
|
||||
)
|
||||
messages = (
|
||||
data.get("messages", [])
|
||||
or data.get("data", {}).get("messages", [])
|
||||
or (
|
||||
inner_data.get("messages", [])
|
||||
if isinstance(inner_data, dict)
|
||||
else []
|
||||
)
|
||||
or response_data.get("messages", [])
|
||||
or data.get("emails", [])
|
||||
or (
|
||||
inner_data.get("emails", [])
|
||||
if isinstance(inner_data, dict)
|
||||
else []
|
||||
)
|
||||
or response_data.get("emails", [])
|
||||
)
|
||||
# Check for pagination token in various possible locations
|
||||
next_token = (
|
||||
data.get("nextPageToken")
|
||||
or data.get("next_page_token")
|
||||
or data.get("data", {}).get("nextPageToken")
|
||||
or data.get("data", {}).get("next_page_token")
|
||||
or (
|
||||
inner_data.get("nextPageToken")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or (
|
||||
inner_data.get("next_page_token")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or response_data.get("nextPageToken")
|
||||
or response_data.get("next_page_token")
|
||||
)
|
||||
# Extract resultSizeEstimate if available (Gmail API provides this)
|
||||
result_size_estimate = (
|
||||
data.get("resultSizeEstimate")
|
||||
or data.get("result_size_estimate")
|
||||
or data.get("data", {}).get("resultSizeEstimate")
|
||||
or data.get("data", {}).get("result_size_estimate")
|
||||
or (
|
||||
inner_data.get("resultSizeEstimate")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or (
|
||||
inner_data.get("result_size_estimate")
|
||||
if isinstance(inner_data, dict)
|
||||
else None
|
||||
)
|
||||
or response_data.get("resultSizeEstimate")
|
||||
or response_data.get("result_size_estimate")
|
||||
)
|
||||
elif isinstance(data, list):
|
||||
messages = data
|
||||
|
|
@ -864,7 +926,7 @@ class ComposioService:
|
|||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
|
||||
tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
|
||||
params={"message_id": message_id}, # snake_case
|
||||
entity_id=entity_id,
|
||||
)
|
||||
|
|
@ -872,7 +934,13 @@ class ComposioService:
|
|||
if not result.get("success"):
|
||||
return None, result.get("error", "Unknown error")
|
||||
|
||||
return result.get("data"), None
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
inner_data = data.get("data", data)
|
||||
if isinstance(inner_data, dict):
|
||||
return inner_data.get("response_data", inner_data), None
|
||||
|
||||
return data, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Gmail message detail: {e!s}")
|
||||
|
|
@ -928,10 +996,27 @@ class ComposioService:
|
|||
# Try different possible response structures
|
||||
events = []
|
||||
if isinstance(data, dict):
|
||||
inner_data = data.get("data", data)
|
||||
response_data = (
|
||||
inner_data.get("response_data", {})
|
||||
if isinstance(inner_data, dict)
|
||||
else {}
|
||||
)
|
||||
events = (
|
||||
data.get("items", [])
|
||||
or data.get("data", {}).get("items", [])
|
||||
or (
|
||||
inner_data.get("items", [])
|
||||
if isinstance(inner_data, dict)
|
||||
else []
|
||||
)
|
||||
or response_data.get("items", [])
|
||||
or data.get("events", [])
|
||||
or (
|
||||
inner_data.get("events", [])
|
||||
if isinstance(inner_data, dict)
|
||||
else []
|
||||
)
|
||||
or response_data.get("events", [])
|
||||
)
|
||||
elif isinstance(data, list):
|
||||
events = data
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
|
@ -2769,12 +2771,22 @@ class ConnectorService:
|
|||
"""
|
||||
Get all available (enabled) connector types for a search space.
|
||||
|
||||
Phase 1.4: results are cached per ``search_space_id`` for
|
||||
:data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
|
||||
identity — the cached value is plain data, safe to share across
|
||||
requests. Invalidate on connector add/update/delete via
|
||||
:func:`invalidate_connector_discovery_cache`.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
List of SearchSourceConnectorType enums for enabled connectors
|
||||
"""
|
||||
cached = _get_cached_connectors(search_space_id)
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
query = (
|
||||
select(SearchSourceConnector.connector_type)
|
||||
.filter(
|
||||
|
|
@ -2784,8 +2796,9 @@ class ConnectorService:
|
|||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
connector_types = result.scalars().all()
|
||||
return list(connector_types)
|
||||
connector_types = list(result.scalars().all())
|
||||
_set_cached_connectors(search_space_id, connector_types)
|
||||
return connector_types
|
||||
|
||||
async def get_available_document_types(
|
||||
self,
|
||||
|
|
@ -2794,12 +2807,22 @@ class ConnectorService:
|
|||
"""
|
||||
Get all document types that have at least one document in the search space.
|
||||
|
||||
Phase 1.4: cached per ``search_space_id`` for
|
||||
:data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
|
||||
:func:`invalidate_connector_discovery_cache` when a connector
|
||||
finishes indexing new documents (or document types are otherwise
|
||||
added/removed).
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
List of document type strings that have documents indexed
|
||||
"""
|
||||
cached = _get_cached_doc_types(search_space_id)
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
from sqlalchemy import distinct
|
||||
|
||||
from app.db import Document
|
||||
|
|
@ -2809,5 +2832,164 @@ class ConnectorService:
|
|||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
doc_types = result.scalars().all()
|
||||
return [str(dt) for dt in doc_types]
|
||||
doc_types = [str(dt) for dt in result.scalars().all()]
|
||||
_set_cached_doc_types(search_space_id, doc_types)
|
||||
return doc_types
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connector / document-type discovery TTL cache (Phase 1.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Both ``get_available_connectors`` and ``get_available_document_types`` are
|
||||
# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
|
||||
# hits Postgres and contributes to per-turn agent build latency. Their
|
||||
# results change infrequently — only when the user adds/edits/removes a
|
||||
# connector, or when an indexer commits a new document type. A short TTL
|
||||
# cache (default 30s, env-tunable) collapses N concurrent calls into one
|
||||
# DB roundtrip with bounded staleness.
|
||||
#
|
||||
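# Invalidation: SQLAlchemy ORM listeners registered by
|
||||
# ``_register_invalidation_listeners`` drop the affected space's entry when
|
||||
# a connector or document row is inserted, updated or deleted through the
|
||||
# ORM; writes that bypass the ORM must call
|
||||
# ``invalidate_connector_discovery_cache(search_space_id)`` explicitly.
|
||||
# Multi-replica deployments still pay one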
|
||||
# DB roundtrip per replica per TTL window, which is fine — staleness is
|
||||
# bounded and the alternative (cross-replica fanout) is not worth the
|
||||
# coupling here.
|
||||
|
||||
def _read_discovery_ttl_seconds(default: float = 30.0) -> float:
    """Parse ``SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS`` without raising.

    The value is read once at import time, so a typo in the environment must
    not prevent this module from loading.

    Args:
        default: TTL in seconds used when the variable is unset, blank,
            not a number, or NaN.

    Returns:
        The configured TTL in seconds. Zero or a negative number is returned
        as-is; the cache helpers treat ``<= 0`` as "caching disabled".
    """
    import logging
    import math

    env_name = "SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS"
    raw = os.getenv(env_name)
    if raw is None or not raw.strip():
        return default

    try:
        ttl = float(raw)
    except ValueError:
        ttl = math.nan

    # NaN compares false against everything: it would slip past the
    # ``<= 0`` "disabled" guard and no entry would ever be seen as expired.
    if math.isnan(ttl):
        logging.getLogger(__name__).warning(
            "Ignoring invalid %s=%r; using default of %s seconds",
            env_name,
            raw,
            default,
        )
        return default
    return ttl


_DISCOVERY_TTL_SECONDS: float = _read_discovery_ttl_seconds()
|
||||
|
||||
# Single lock guarding both discovery caches below; every read, write and
# eviction in this module takes it.
_cache_lock = Lock()

# Discovery caches, one entry per ``search_space_id``. Each value is a
# ``(expires_at_monotonic, payload)`` pair, where the first element is a
# ``time.monotonic()`` deadline and the second is the cached list.
_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
|
||||
|
||||
|
||||
def _get_cached_connectors(
    search_space_id: int,
) -> list[SearchSourceConnectorType] | None:
    """Return the cached connector types for a search space, if still fresh.

    Returns ``None`` when caching is disabled (TTL ``<= 0``), when nothing
    is cached for ``search_space_id``, or when the cached entry has expired
    (the expired entry is evicted on the way out). The returned list is the
    stored object itself, not a copy.
    """
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None

    with _cache_lock:
        cached = _connectors_cache.get(search_space_id)
        if cached is not None:
            deadline, connector_types = cached
            expired = time.monotonic() >= deadline
            if not expired:
                return connector_types
            # Stale entry: remove it so the next writer starts clean.
            _connectors_cache.pop(search_space_id, None)
    return None
|
||||
|
||||
|
||||
def _set_cached_connectors(
    search_space_id: int, payload: list[SearchSourceConnectorType]
) -> None:
    """Store a copy of ``payload`` as the connector list for a search space.

    Does nothing when caching is disabled (TTL ``<= 0``). The entry expires
    ``_DISCOVERY_TTL_SECONDS`` after the moment of the call, measured on the
    monotonic clock.
    """
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None

    deadline = time.monotonic() + _DISCOVERY_TTL_SECONDS
    with _cache_lock:
        # Copy so later mutation of the caller's list cannot alter the cache.
        snapshot = list(payload)
        _connectors_cache[search_space_id] = (deadline, snapshot)
    return None
|
||||
|
||||
|
||||
def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
    """Return the cached document types for a search space, if still fresh.

    Returns ``None`` when caching is disabled (TTL ``<= 0``), when there is
    no entry for ``search_space_id``, or when the entry has expired; an
    expired entry is evicted before returning. The returned list is the
    stored object itself, not a copy.
    """
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None

    with _cache_lock:
        cached = _doc_types_cache.get(search_space_id)
        if cached is not None:
            deadline, doc_types = cached
            expired = time.monotonic() >= deadline
            if not expired:
                return doc_types
            # Stale entry: remove it so the next writer starts clean.
            _doc_types_cache.pop(search_space_id, None)
    return None
|
||||
|
||||
|
||||
def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
    """Store a copy of ``payload`` as the document-type list for a search space.

    Does nothing when caching is disabled (TTL ``<= 0``). The entry expires
    ``_DISCOVERY_TTL_SECONDS`` after the moment of the call, measured on the
    monotonic clock.
    """
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None

    deadline = time.monotonic() + _DISCOVERY_TTL_SECONDS
    with _cache_lock:
        # Copy so later mutation of the caller's list cannot alter the cache.
        snapshot = list(payload)
        _doc_types_cache[search_space_id] = (deadline, snapshot)
    return None
|
||||
|
||||
|
||||
def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
    """Forget cached discovery results for one search space, or for all of them.

    Intended for code that mutates the rows behind
    :func:`ConnectorService.get_available_connectors` /
    :func:`get_available_document_types`.

    Args:
        search_space_id: Space whose connector and document-type entries
            should be dropped. ``None`` empties both caches entirely, which
            is handy in tests and after bulk imports.
    """
    with _cache_lock:
        both_caches = (_connectors_cache, _doc_types_cache)
        if search_space_id is not None:
            for cache in both_caches:
                cache.pop(search_space_id, None)
            return
        for cache in both_caches:
            cache.clear()
|
||||
|
||||
|
||||
def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
    """Drop cached connector types for one space; ``None`` clears every space.

    The document-type cache is left untouched.
    """
    with _cache_lock:
        if search_space_id is not None:
            _connectors_cache.pop(search_space_id, None)
            return
        _connectors_cache.clear()
|
||||
|
||||
|
||||
def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
    """Drop cached document types for one space; ``None`` clears every space.

    The connector cache is left untouched.
    """
    with _cache_lock:
        if search_space_id is not None:
            _doc_types_cache.pop(search_space_id, None)
            return
        _doc_types_cache.clear()
|
||||
|
||||
|
||||
def _register_invalidation_listeners() -> None:
    """Hook SQLAlchemy ORM events so the discovery caches self-invalidate.

    ``after_insert`` / ``after_update`` / ``after_delete`` listeners are
    attached to ``SearchSourceConnector`` and ``Document``. Whenever one of
    those events fires for a mapped object that carries a non-null
    ``search_space_id``, the matching cache entry for that space is dropped
    (connector cache for connectors, document-type cache for documents), so
    individual routes do not have to call ``invalidate_*`` themselves.

    Statements that bypass the ORM (for example
    ``session.execute(insert(...))`` without a mapped object) do not fire
    these events and still need explicit invalidation.
    """
    from sqlalchemy import event

    # Deferred import: importing app.db at module top would create a cycle,
    # because this module is pulled in indirectly by code around app.db.
    from app.db import Document, SearchSourceConnector

    def _space_id_of(target):
        # Mapped rows without a search space yield None and are ignored.
        raw_id = getattr(target, "search_space_id", None)
        return None if raw_id is None else int(raw_id)

    def _on_connector_write(_mapper, _connection, target) -> None:
        space_id = _space_id_of(target)
        if space_id is not None:
            _invalidate_connectors_only(space_id)

    def _on_document_write(_mapper, _connection, target) -> None:
        space_id = _space_id_of(target)
        if space_id is not None:
            _invalidate_doc_types_only(space_id)

    for hook_name in ("after_insert", "after_update", "after_delete"):
        event.listen(SearchSourceConnector, hook_name, _on_connector_write)
        event.listen(Document, hook_name, _on_document_write)


# Register at import time. A failure here only costs automatic invalidation,
# so it is logged instead of being allowed to break the import.
try:
    _register_invalidation_listeners()
except Exception:  # pragma: no cover - defensive; never block module import
    import logging as _logging

    _logging.getLogger(__name__).exception(
        "Failed to register connector discovery cache invalidation listeners; "
        "stale cache risk: explicit invalidate_connector_discovery_cache calls "
        "may be required."
    )
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -78,14 +78,49 @@ class GmailToolMetadataService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if (
|
||||
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||
return (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
)
|
||||
|
||||
def _get_composio_connected_account_id(
    self, connector: SearchSourceConnector
) -> str:
    """Return the Composio connected-account id stored in the connector config.

    Raises:
        ValueError: If ``composio_connected_account_id`` is missing or falsy.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if account_id:
        return account_id
    raise ValueError("Composio connected_account_id not found")
|
||||
|
||||
def _unwrap_composio_data(self, data: Any) -> Any:
    """Peel the optional ``data`` / ``response_data`` envelopes off a result.

    Non-dict input is returned unchanged. For a dict, the ``"data"`` value
    is used when present (otherwise the dict itself); if that is also a
    dict, its ``"response_data"`` value is returned when present, else the
    inner dict.
    """
    if not isinstance(data, dict):
        return data

    inner = data.get("data", data)
    if not isinstance(inner, dict):
        return inner

    return inner.get("response_data", inner)
|
||||
|
||||
async def _execute_composio_gmail_tool(
    self,
    connector: SearchSourceConnector,
    tool_name: str,
    params: dict[str, Any],
) -> tuple[Any, str | None]:
    """Run one Composio Gmail tool for ``connector``.

    Returns:
        ``(payload, None)`` on success, where ``payload`` is the tool result
        with its envelopes removed by ``_unwrap_composio_data``; otherwise
        ``(None, error_message)``.

    Raises:
        ValueError: If the connector has no Composio connected-account id.
    """
    composio = ComposioService()
    account_id = self._get_composio_connected_account_id(connector)
    entity_id = f"surfsense_{connector.user_id}"

    outcome = await composio.execute_tool(
        connected_account_id=account_id,
        tool_name=tool_name,
        params=params,
        entity_id=entity_id,
    )
    if outcome.get("success"):
        return self._unwrap_composio_data(outcome.get("data")), None
    return None, outcome.get("error", "Unknown Composio Gmail error")
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if self._is_composio_connector(connector):
|
||||
raise ValueError(
|
||||
"Composio Gmail connectors must use Composio tool execution"
|
||||
)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -139,6 +174,12 @@ class GmailToolMetadataService:
|
|||
if not connector:
|
||||
return True
|
||||
|
||||
if self._is_composio_connector(connector):
|
||||
_profile, error = await self._execute_composio_gmail_tool(
|
||||
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||
)
|
||||
return bool(error)
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
|
|
@ -221,14 +262,21 @@ class GmailToolMetadataService:
|
|||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
if connector:
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
profile = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda service=service: (
|
||||
service.users().getProfile(userId="me").execute()
|
||||
),
|
||||
)
|
||||
if self._is_composio_connector(connector):
|
||||
profile, error = await self._execute_composio_gmail_tool(
|
||||
connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
else:
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
profile = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda service=service: (
|
||||
service.users().getProfile(userId="me").execute()
|
||||
),
|
||||
)
|
||||
acc_dict["email"] = profile.get("emailAddress", "")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
|
|
@ -298,6 +346,23 @@ class GmailToolMetadataService:
|
|||
Returns ``None`` on any failure so callers can degrade gracefully.
|
||||
"""
|
||||
try:
|
||||
if self._is_composio_connector(connector):
|
||||
if not draft_id:
|
||||
draft_id = await self._find_composio_draft_id(connector, message_id)
|
||||
if not draft_id:
|
||||
return None
|
||||
|
||||
draft, error = await self._execute_composio_gmail_tool(
|
||||
connector,
|
||||
"GMAIL_GET_DRAFT",
|
||||
{"user_id": "me", "draft_id": draft_id, "format": "full"},
|
||||
)
|
||||
if error or not isinstance(draft, dict):
|
||||
return None
|
||||
|
||||
payload = draft.get("message", {}).get("payload", {})
|
||||
return self._extract_body_from_payload(payload)
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
|
|
@ -326,6 +391,33 @@ class GmailToolMetadataService:
|
|||
)
|
||||
return None
|
||||
|
||||
async def _find_composio_draft_id(
    self, connector: SearchSourceConnector, message_id: str
) -> str | None:
    """Find the draft whose message id is ``message_id`` via Composio.

    Pages through ``GMAIL_LIST_DRAFTS`` (100 drafts per call) until a draft
    with a matching ``message.id`` is found.

    Args:
        connector: Composio Gmail connector used to execute the tool.
        message_id: Gmail message id of the draft being looked up.

    Returns:
        The draft id, or ``None`` when no draft matches, a tool call fails,
        the response is not a dict, or pagination stops making progress.
    """
    page_token = ""
    # Tokens already requested; guards against a response that keeps handing
    # back the same token, which would otherwise loop forever.
    seen_tokens: set[str] = set()

    while True:
        params: dict[str, Any] = {
            "user_id": "me",
            "max_results": 100,
            "verbose": False,
        }
        if page_token:
            params["page_token"] = page_token

        data, error = await self._execute_composio_gmail_tool(
            connector, "GMAIL_LIST_DRAFTS", params
        )
        if error or not isinstance(data, dict):
            return None

        # ``drafts`` / ``message`` may be present but null; treat as empty.
        for draft in data.get("drafts") or []:
            if not isinstance(draft, dict):
                continue
            message = draft.get("message") or {}
            if isinstance(message, dict) and message.get("id") == message_id:
                return draft.get("id")

        page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
        if not page_token or page_token in seen_tokens:
            return None
        seen_tokens.add(page_token)
|
||||
|
||||
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
|
||||
"""Resolve a draft ID from its message ID by scanning drafts.list."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
|
|
@ -21,7 +22,6 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
|
|||
logger.warning("Document %s not found in KB", document_id)
|
||||
return {"status": "not_indexed"}
|
||||
|
||||
creds = await self._build_credentials_for_connector(connector_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
calendar_id = (document.document_metadata or {}).get(
|
||||
"calendar_id"
|
||||
) or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
connector = await self._get_connector(connector_id)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
composio_result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLECALENDAR_EVENTS_GET",
|
||||
params={"calendar_id": calendar_id, "event_id": event_id},
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not composio_result.get("success"):
|
||||
raise RuntimeError(
|
||||
composio_result.get("error", "Unknown Composio Calendar error")
|
||||
)
|
||||
live_event = composio_result.get("data", {})
|
||||
if isinstance(live_event, dict):
|
||||
live_event = live_event.get("data", live_event)
|
||||
if isinstance(live_event, dict):
|
||||
live_event = live_event.get("response_data", live_event)
|
||||
else:
|
||||
creds = await self._build_credentials_for_connector(connector_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_summary = live_event.get("summary", "")
|
||||
description = live_event.get("description", "")
|
||||
|
|
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
|
|||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||
async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
|
||||
result = await self.db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
|
|
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
|
|||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
raise ValueError(f"Connector {connector_id} not found")
|
||||
return connector
|
||||
|
||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||
connector = await self._get_connector(connector_id)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
raise ValueError(
|
||||
"Composio Calendar connectors must use Composio tool execution"
|
||||
)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if (
|
||||
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||
return (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
)
|
||||
|
||||
def _get_composio_connected_account_id(
    self, connector: SearchSourceConnector
) -> str:
    """Return the Composio connected-account id stored in the connector config.

    Raises:
        ValueError: If ``composio_connected_account_id`` is missing or falsy.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if account_id:
        return account_id
    raise ValueError("Composio connected_account_id not found")
|
||||
|
||||
async def _execute_composio_calendar_tool(
    self,
    connector: SearchSourceConnector,
    tool_name: str,
    params: dict,
) -> tuple[dict | list | None, str | None]:
    """Run one Composio Google Calendar tool for ``connector``.

    Returns:
        ``(payload, None)`` on success, where ``payload`` is the tool result
        with its optional ``data`` / ``response_data`` envelopes removed;
        otherwise ``(None, error_message)``.

    Raises:
        ValueError: If the connector has no Composio connected-account id.
    """
    composio = ComposioService()
    account_id = self._get_composio_connected_account_id(connector)
    entity_id = f"surfsense_{connector.user_id}"

    outcome = await composio.execute_tool(
        connected_account_id=account_id,
        tool_name=tool_name,
        params=params,
        entity_id=entity_id,
    )
    if not outcome.get("success"):
        return None, outcome.get("error", "Unknown Composio Calendar error")

    payload = outcome.get("data")
    if not isinstance(payload, dict):
        return payload, None

    # Some responses nest the body under "data", some return it directly.
    inner = payload.get("data", payload)
    if not isinstance(inner, dict):
        return inner, None
    return inner.get("response_data", inner), None
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if self._is_composio_connector(connector):
|
||||
raise ValueError(
|
||||
"Composio Calendar connectors must use Composio tool execution"
|
||||
)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
|
|||
if not connector:
|
||||
return True
|
||||
|
||||
if self._is_composio_connector(connector):
|
||||
_data, error = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_GET_CALENDAR",
|
||||
{"calendar_id": "primary"},
|
||||
)
|
||||
return bool(error)
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
|
|
@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
|
|||
timezone_str = ""
|
||||
if connector:
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if self._is_composio_connector(connector):
|
||||
cal_list, cal_error = await self._execute_composio_calendar_tool(
|
||||
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
|
||||
)
|
||||
if cal_error:
|
||||
raise RuntimeError(cal_error)
|
||||
(
|
||||
settings,
|
||||
settings_error,
|
||||
) = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_SETTINGS_GET",
|
||||
{"setting": "timezone"},
|
||||
)
|
||||
if not settings_error and isinstance(settings, dict):
|
||||
timezone_str = settings.get("value", "")
|
||||
else:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
cal_list = await loop.run_in_executor(
|
||||
None, lambda: service.calendarList().list().execute()
|
||||
)
|
||||
for cal in cal_list.get("items", []):
|
||||
cal_list = await loop.run_in_executor(
|
||||
None, lambda: service.calendarList().list().execute()
|
||||
)
|
||||
|
||||
tz_setting = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: service.settings().get(setting="timezone").execute(),
|
||||
)
|
||||
timezone_str = tz_setting.get("value", "")
|
||||
|
||||
calendar_items = []
|
||||
if isinstance(cal_list, dict):
|
||||
calendar_items = (
|
||||
cal_list.get("items") or cal_list.get("calendars") or []
|
||||
)
|
||||
elif isinstance(cal_list, list):
|
||||
calendar_items = cal_list
|
||||
|
||||
for cal in calendar_items:
|
||||
calendars.append(
|
||||
{
|
||||
"id": cal.get("id", ""),
|
||||
|
|
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
|
|||
"primary": cal.get("primary", False),
|
||||
}
|
||||
)
|
||||
|
||||
tz_setting = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: service.settings().get(setting="timezone").execute(),
|
||||
)
|
||||
timezone_str = tz_setting.get("value", "")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch calendars/timezone for connector %s",
|
||||
|
|
@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
|
|||
|
||||
event_dict = event.to_dict()
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
calendar_id = event.calendar_id or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
if self._is_composio_connector(connector):
|
||||
live_event, error = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_EVENTS_GET",
|
||||
{"calendar_id": calendar_id, "event_id": event.event_id},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
else:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
|
||||
event_dict["description"] = live_event.get(
|
||||
|
|
@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
|
|||
) -> dict:
|
||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||
if not resolved:
|
||||
live_resolved = await self._resolve_live_event(
|
||||
search_space_id, user_id, event_ref
|
||||
)
|
||||
if not live_resolved:
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, "
|
||||
"(2) the event name is different, or "
|
||||
"(3) the connected calendar account cannot access it."
|
||||
)
|
||||
}
|
||||
|
||||
connector, live_event = live_resolved
|
||||
account = GoogleCalendarAccount.from_connector(connector)
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the event name is different."
|
||||
)
|
||||
"account": acc_dict,
|
||||
"event": self._event_dict_from_live_event(live_event),
|
||||
}
|
||||
|
||||
document, connector = resolved
|
||||
|
|
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
|
|||
if row:
|
||||
return row[0], row[1]
|
||||
return None
|
||||
|
||||
async def _resolve_live_event(
    self, search_space_id: int, user_id: str, event_ref: str
) -> tuple[SearchSourceConnector, dict] | None:
    """Look up ``event_ref`` by searching the user's calendar connectors live.

    Loads every connector in ``search_space_id`` owned by ``user_id`` whose
    type is in ``CALENDAR_CONNECTOR_TYPES`` (ordered by ``last_indexed_at``
    descending) and searches them one by one with ``_search_live_events``.
    The first connector that returns any event wins.

    Args:
        search_space_id: Search space whose connectors are queried.
        user_id: Owner of the connectors.
        event_ref: Free-text reference (typically the event title).

    Returns:
        ``(connector, event)`` where ``event`` is the search result whose
        summary equals ``event_ref`` ignoring case and surrounding
        whitespace, or the first search result when nothing matches exactly.
        ``None`` when no connector returns an event.
    """
    result = await self._db_session.execute(
        select(SearchSourceConnector)
        .filter(
            and_(
                SearchSourceConnector.search_space_id == search_space_id,
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
            )
        )
        .order_by(SearchSourceConnector.last_indexed_at.desc())
    )
    connectors = result.scalars().all()

    # Loop-invariant: normalise the reference once instead of per connector.
    normalized_ref = event_ref.strip().lower()

    for connector in connectors:
        try:
            events = await self._search_live_events(connector, event_ref)
        except Exception:
            # Best effort: a failing connector must not hide the others.
            logger.warning(
                "Failed to search live calendar events for connector %s",
                connector.id,
                exc_info=True,
            )
            continue

        if not events:
            continue

        # ``summary`` may be absent *or* explicitly null; ``or ""`` keeps the
        # comparison from raising outside the try block above.
        exact_match = next(
            (
                event
                for event in events
                if (event.get("summary") or "").strip().lower() == normalized_ref
            ),
            None,
        )
        return connector, exact_match or events[0]

    return None
|
||||
|
||||
async def _search_live_events(
    self, connector: SearchSourceConnector, event_ref: str
) -> list[dict]:
    """Run a free-text event search against the connector's primary calendar.

    Composio connectors go through ``_execute_composio_calendar_tool`` with
    the ``GOOGLECALENDAR_EVENTS_LIST`` tool; every other connector uses the
    Google Calendar client built from ``_build_credentials``. Both paths ask
    for at most 10 expanded (single) events ordered by start time.

    Args:
        connector: Calendar connector to query.
        event_ref: Text passed as the ``q`` search parameter.

    Returns:
        The list of matching event dicts (possibly empty).

    Raises:
        RuntimeError: When the Composio tool call reports an error.
    """
    if self._is_composio_connector(connector):
        data, error = await self._execute_composio_calendar_tool(
            connector,
            "GOOGLECALENDAR_EVENTS_LIST",
            {
                "calendar_id": "primary",
                "q": event_ref,
                "max_results": 10,
                "single_events": True,
                "order_by": "startTime",
            },
        )
        if error:
            raise RuntimeError(error)
        # The payload is either a dict wrapping the list or the list itself.
        if isinstance(data, dict):
            return data.get("items") or data.get("events") or []
        return data if isinstance(data, list) else []

    creds = await self._build_credentials(connector)
    # Inside a coroutine the running loop is the one we want; the asyncio
    # docs prefer ``get_running_loop`` over ``get_event_loop`` here.
    loop = asyncio.get_running_loop()
    # The Google client is blocking, so build and execute in the executor.
    service = await loop.run_in_executor(
        None, lambda: build("calendar", "v3", credentials=creds)
    )
    response = await loop.run_in_executor(
        None,
        lambda: (
            service.events()
            .list(
                calendarId="primary",
                q=event_ref,
                maxResults=10,
                singleEvents=True,
                orderBy="startTime",
            )
            .execute()
        ),
    )
    # ``or []`` mirrors the Composio branch: a null ``items`` yields a list.
    return response.get("items") or []
|
||||
|
||||
def _event_dict_from_live_event(self, event: dict) -> dict:
|
||||
start_data = event.get("start", {})
|
||||
end_data = event.get("end", {})
|
||||
return {
|
||||
"event_id": event.get("id", ""),
|
||||
"summary": event.get("summary", "No Title"),
|
||||
"start": start_data.get("dateTime", start_data.get("date", "")),
|
||||
"end": end_data.get("dateTime", end_data.get("date", "")),
|
||||
"description": event.get("description", ""),
|
||||
"location": event.get("location", ""),
|
||||
"attendees": [
|
||||
{
|
||||
"email": attendee.get("email", ""),
|
||||
"responseStatus": attendee.get("responseStatus", ""),
|
||||
}
|
||||
for attendee in event.get("attendees", [])
|
||||
],
|
||||
"calendar_id": event.get("calendarId", "primary"),
|
||||
"document_id": None,
|
||||
"indexed_at": None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
    """Tell whether ``connector`` is the Composio Google Drive connector type."""
    kind = connector.connector_type
    return kind == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
|
||||
def _get_composio_connected_account_id(
    self, connector: SearchSourceConnector
) -> str:
    """Read the Composio connected-account id from the connector config.

    Raises:
        ValueError: When ``composio_connected_account_id`` is missing or
            empty in ``connector.config``.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if account_id:
        return account_id
    raise ValueError("Composio connected_account_id not found")
|
||||
|
||||
async def _execute_composio_drive_tool(
    self,
    connector: SearchSourceConnector,
    tool_name: str,
    params: dict,
) -> tuple[dict | list | None, str | None]:
    """Execute a Composio Drive tool and unwrap its response envelope.

    Args:
        connector: Composio-backed connector supplying the connected account
            id and the user id used to build the entity id.
        tool_name: Name of the Composio tool to run.
        params: Parameters forwarded to the tool.

    Returns:
        ``(payload, None)`` on success, ``(None, error_message)`` when the
        result is not flagged ``success``. On success the payload is peeled
        out of an optional nested ``data`` key and then an optional
        ``response_data`` key.
    """
    service = ComposioService()
    outcome = await service.execute_tool(
        connected_account_id=self._get_composio_connected_account_id(connector),
        tool_name=tool_name,
        params=params,
        entity_id=f"surfsense_{connector.user_id}",
    )
    if not outcome.get("success"):
        return None, outcome.get("error", "Unknown Composio Drive error")

    payload = outcome.get("data")
    if not isinstance(payload, dict):
        return payload, None

    # First envelope layer: an optional nested ``data`` key.
    unwrapped = payload.get("data", payload)
    if not isinstance(unwrapped, dict):
        return unwrapped, None

    # Second envelope layer: an optional ``response_data`` key.
    return unwrapped.get("response_data", unwrapped), None
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||
|
||||
|
|
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
|
|||
if not connector:
|
||||
return True
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if self._is_composio_connector(connector):
|
||||
_data, error = await self._execute_composio_drive_tool(
|
||||
connector,
|
||||
"GOOGLEDRIVE_LIST_FILES",
|
||||
{
|
||||
"q": "trashed = false",
|
||||
"page_size": 1,
|
||||
"fields": "files(id)",
|
||||
},
|
||||
)
|
||||
return bool(error)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
await client.list_files(
|
||||
query="trashed = false", page_size=1, fields="files(id)"
|
||||
|
|
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
|
|||
parent_folders[connector_id] = []
|
||||
continue
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if self._is_composio_connector(connector):
|
||||
data, error = await self._execute_composio_drive_tool(
|
||||
connector,
|
||||
"GOOGLEDRIVE_LIST_FILES",
|
||||
{
|
||||
"q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
|
||||
"fields": "files(id,name)",
|
||||
"page_size": 50,
|
||||
},
|
||||
)
|
||||
if error:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s",
|
||||
connector_id,
|
||||
error,
|
||||
)
|
||||
parent_folders[connector_id] = []
|
||||
continue
|
||||
folders = []
|
||||
if isinstance(data, dict):
|
||||
folders = data.get("files", [])
|
||||
elif isinstance(data, list):
|
||||
folders = data
|
||||
parent_folders[connector_id] = [
|
||||
{"folder_id": f["id"], "name": f["name"]}
|
||||
for f in folders
|
||||
if f.get("id") and f.get("name")
|
||||
]
|
||||
continue
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
|
||||
folders, _, error = await client.list_files(
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
|
|
@ -152,12 +154,12 @@ class ImageGenRouterService:
|
|||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
|
|
@ -165,9 +167,16 @@ class ImageGenRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
# Resolve ``api_base`` so deployments don't silently inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||
# the wrong provider (see ``provider_api_base`` docstring).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
|
|
|
|||
|
|
@ -134,42 +134,14 @@ PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
||||
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
||||
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
||||
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
||||
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
||||
# Only providers with a well-known, stable public base URL are listed here —
|
||||
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
# so their existing config-driven behaviour is preserved.
|
||||
PROVIDER_DEFAULT_API_BASE = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
|
||||
|
||||
# Canonical provider → base URL when a config uses a generic ``openai``-style
|
||||
# prefix but the ``provider`` field tells us which API it really is
|
||||
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
|
||||
# each has its own base URL).
|
||||
PROVIDER_KEY_DEFAULT_API_BASE = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
|
||||
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
|
||||
# call sites can share the exact same defense (OpenRouter / Groq / etc.
|
||||
# 404-ing against an inherited Azure endpoint). Only ``resolve_api_base``
|
||||
# is imported below; the constants now live solely in ``provider_api_base``.
|
||||
from app.services.provider_api_base import ( # noqa: E402
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
|
|
@ -466,14 +438,14 @@ class LLMRouterService:
|
|||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
||||
# requests to the wrong endpoint. See ``provider_api_base``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = config.get("api_base")
|
||||
if not api_base:
|
||||
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
||||
if not api_base:
|
||||
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.services.llm_router_service import (
|
|||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -496,8 +497,14 @@ async def get_vision_llm(
|
|||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
|
||||
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||
pool. User-owned BYOK configs and free global configs are returned
|
||||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
|
|
@ -519,6 +526,8 @@ async def get_vision_llm(
|
|||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
|
|
@ -526,6 +535,13 @@ async def get_vision_llm(
|
|||
)
|
||||
return None
|
||||
try:
|
||||
# Auto mode is currently treated as free at the wrapper
|
||||
# level — the underlying router can dispatch to either
|
||||
# premium or free YAML configs but routing decisions are
|
||||
# opaque. If/when we want to bill Auto-routed vision
|
||||
# calls we'd need to thread the resolved deployment's
|
||||
# billing_tier back from the router. For now we keep
|
||||
# parity with chat Auto, which also doesn't pre-classify.
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
|
|
@ -541,29 +557,46 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
|
||||
)
|
||||
provider_prefix = global_cfg["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{global_cfg['model_name']}"
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
if global_cfg.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=global_cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=global_cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||
if billing_tier == "premium":
|
||||
return QuotaCheckedVisionLLM(
|
||||
inner_llm,
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=model_string,
|
||||
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||
)
|
||||
return inner_llm
|
||||
|
||||
# User-owned (positive ID) BYOK configs — always free.
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
|
|
@ -578,20 +611,26 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
|
||||
provider_prefix = vision_cfg.custom_provider
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{vision_cfg.model_name}"
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
if vision_cfg.api_base:
|
||||
litellm_kwargs["api_base"] = vision_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=vision_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=vision_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
|
|
|
|||
|
|
@ -93,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool:
|
|||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _is_image_output_model(model: dict) -> bool:
|
||||
"""Return True if the model can produce image output.
|
||||
|
||||
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
|
||||
``["image"]`` for pure image generators, ``["text", "image"]`` for
|
||||
multi-modal generators that also emit captions). We accept any model
|
||||
that can output images; the call site decides whether to use the
|
||||
image-generation API or chat completion.
|
||||
"""
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
|
||||
return "image" in output_mods
|
||||
|
||||
|
||||
def _is_vision_input_model(model: dict) -> bool:
|
||||
"""Return True if the model can ingest an image AND emit text.
|
||||
|
||||
OpenRouter's ``architecture.input_modalities`` lists what the model
|
||||
accepts; ``output_modalities`` lists what it produces. A vision LLM
|
||||
is a model that takes images in and produces text out — i.e. it can
|
||||
answer questions about a screenshot or extract content from an
|
||||
image. Pure image-to-image models (e.g. style transfer) and
|
||||
text-only models are excluded.
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
output_mods = arch.get("output_modalities", []) or []
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _supports_image_input(model: dict) -> bool:
|
||||
"""Return True if the model accepts ``image`` in its input modalities.
|
||||
|
||||
Differs from :func:`_is_vision_input_model` in that it does NOT
|
||||
require text output — chat-tab models always emit text already (the
|
||||
chat catalog filters by ``_is_text_output_model``), so the only
|
||||
extra capability we need to track per chat config is whether the
|
||||
model can ingest user-attached images. The chat selector and the
|
||||
streaming task both key off this flag to prevent hitting an
|
||||
OpenRouter 404 ``"No endpoints found that support image input"``
|
||||
when the user uploads an image and selects a text-only model
|
||||
(DeepSeek V3, Llama 3.x base, etc.).
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
return "image" in input_mods
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
|
|
@ -175,6 +222,32 @@ async def _fetch_models_async() -> list[dict] | None:
|
|||
return None
|
||||
|
||||
|
||||
def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
|
||||
"""Return a ``{model_id: {"prompt": str, "completion": str}}`` map.
|
||||
|
||||
Pricing values are kept as the raw OpenRouter strings (e.g.
|
||||
``"0.000003"``); ``pricing_registration`` converts them to floats
|
||||
when registering with LiteLLM. Models with missing or malformed
|
||||
pricing are simply omitted — operator-side risk if any of those are
|
||||
premium.
|
||||
"""
|
||||
pricing: dict[str, dict[str, str]] = {}
|
||||
for model in raw_models:
|
||||
model_id = str(model.get("id") or "").strip()
|
||||
if not model_id:
|
||||
continue
|
||||
p = model.get("pricing") or {}
|
||||
prompt = p.get("prompt")
|
||||
completion = p.get("completion")
|
||||
if prompt is None and completion is None:
|
||||
continue
|
||||
pricing[model_id] = {
|
||||
"prompt": str(prompt) if prompt is not None else "",
|
||||
"completion": str(completion) if completion is not None else "",
|
||||
}
|
||||
return pricing
|
||||
|
||||
|
||||
def _generate_configs(
|
||||
raw_models: list[dict],
|
||||
settings: dict[str, Any],
|
||||
|
|
@ -266,6 +339,13 @@ def _generate_configs(
|
|||
# account-wide quota, so per-deployment routing can't spread load
|
||||
# there — it just drains the shared bucket faster.
|
||||
"router_pool_eligible": tier == "premium",
|
||||
# Capability flag derived from ``architecture.input_modalities``.
|
||||
# Read by the new-chat selector to dim image-incompatible models
|
||||
# when the user has pending image attachments, and by
|
||||
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||
# OpenRouter request would otherwise 404 with
|
||||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": _supports_image_input(model),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
|
|
@ -282,6 +362,171 @@ def _generate_configs(
|
|||
return configs
|
||||
|
||||
|
||||
# ID-offset bands used to keep dynamic OpenRouter configs in their own
|
||||
# namespace per surface. Image / vision get separate bands so a single
|
||||
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
|
||||
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
|
||||
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||
|
||||
|
||||
def _generate_image_gen_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Build global image-generation config dicts from OpenRouter models.

    A model is kept when all of the following hold:
    - ``_is_image_output_model`` accepts it
    - ``_is_compatible_provider`` accepts it
    - ``_is_allowed_model`` accepts it
    - its id contains a ``/``

    The chat-only checks (tool calling, context size) are deliberately not
    applied here. Each config's ``billing_tier`` comes from
    ``_openrouter_tier``; free-tier models get ``free_rpm``, all others
    ``rpm``. No per-token cost fields are emitted for image generation.

    Args:
        raw_models: Raw OpenRouter catalogue entries.
        settings: Integration settings; reads ``image_id_offset``,
            ``api_key``, ``rpm``, ``free_rpm`` and ``litellm_params``.

    Returns:
        One config dict per eligible model, in catalogue order.
    """
    offset: int = int(
        settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
    )
    key: str = settings.get("api_key", "")
    paid_rpm: int = settings.get("rpm", 200)
    free_rpm: int = settings.get("free_rpm", 20)
    extra_params: dict = settings.get("litellm_params") or {}

    def _eligible(candidate: dict) -> bool:
        return (
            _is_image_output_model(candidate)
            and _is_compatible_provider(candidate)
            and _is_allowed_model(candidate)
            and "/" in candidate.get("id", "")
        )

    # Filter the whole catalogue first, then build configs.
    eligible = list(filter(_eligible, raw_models))

    used_ids: set[int] = set()
    generated: list[dict] = []
    for entry in eligible:
        slug: str = entry["id"]
        label: str = entry.get("name", slug)
        billing = _openrouter_tier(entry)

        generated.append(
            {
                "id": _stable_config_id(slug, offset, used_ids),
                "name": label,
                "description": f"{label} via OpenRouter (image generation)",
                "provider": "OPENROUTER",
                "model_name": slug,
                "api_key": key,
                # Pinned explicitly so a call site that skips
                # ``resolve_api_base`` still targets OpenRouter rather than
                # an endpoint inherited from the environment.
                "api_base": "https://openrouter.ai/api/v1",
                "api_version": None,
                "rpm": free_rpm if billing == "free" else paid_rpm,
                # Copied so configs never share one mutable params dict.
                "litellm_params": dict(extra_params),
                "billing_tier": billing,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )

    return generated
|
||||
|
||||
|
||||
def _generate_vision_llm_configs(
    raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
    """Build global vision-LLM config dicts from OpenRouter models.

    A model is kept when all of the following hold:
    - ``_is_vision_input_model`` accepts it (image in, text out)
    - ``_is_compatible_provider`` accepts it
    - ``_is_allowed_model`` accepts it
    - its id contains a ``/``

    The chat-only checks (tool calling, context size) are deliberately not
    applied here. Per-token prices are parsed from the model's ``pricing``
    dict; a value that is missing, unparsable or zero is emitted as ``None``.

    Args:
        raw_models: Raw OpenRouter catalogue entries.
        settings: Integration settings; reads ``vision_id_offset``,
            ``api_key``, ``rpm``, ``tpm``, ``free_rpm``, ``free_tpm``,
            ``quota_reserve_tokens`` and ``litellm_params``.

    Returns:
        One config dict per eligible model, in catalogue order.
    """
    offset: int = int(
        settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
    )
    key: str = settings.get("api_key", "")
    paid_rpm: int = settings.get("rpm", 200)
    paid_tpm: int = settings.get("tpm", 1_000_000)
    free_rpm: int = settings.get("free_rpm", 20)
    free_tpm: int = settings.get("free_tpm", 100_000)
    reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
    extra_params: dict = settings.get("litellm_params") or {}

    def _eligible(candidate: dict) -> bool:
        return (
            _is_vision_input_model(candidate)
            and _is_compatible_provider(candidate)
            and _is_allowed_model(candidate)
            and "/" in candidate.get("id", "")
        )

    def _rate(price_table: Any, field: str) -> float:
        # Unparsable prices degrade to 0.0, which is emitted as ``None``.
        try:
            return float(price_table.get(field, 0) or 0)
        except (TypeError, ValueError):
            return 0.0

    # Filter the whole catalogue first, then build configs.
    eligible = list(filter(_eligible, raw_models))

    used_ids: set[int] = set()
    generated: list[dict] = []
    for entry in eligible:
        slug: str = entry["id"]
        label: str = entry.get("name", slug)
        billing = _openrouter_tier(entry)
        price_table = entry.get("pricing") or {}

        prompt_cost = _rate(price_table, "prompt")
        completion_cost = _rate(price_table, "completion")
        is_free = billing == "free"

        generated.append(
            {
                "id": _stable_config_id(slug, offset, used_ids),
                "name": label,
                "description": f"{label} via OpenRouter (vision)",
                "provider": "OPENROUTER",
                "model_name": slug,
                "api_key": key,
                # Pinned explicitly so a call site that skips
                # ``resolve_api_base`` still targets OpenRouter rather than
                # an endpoint inherited from the environment.
                "api_base": "https://openrouter.ai/api/v1",
                "api_version": None,
                "rpm": free_rpm if is_free else paid_rpm,
                "tpm": free_tpm if is_free else paid_tpm,
                # Copied so configs never share one mutable params dict.
                "litellm_params": dict(extra_params),
                "billing_tier": billing,
                "quota_reserve_tokens": reserve_tokens,
                "input_cost_per_token": prompt_cost or None,
                "output_cost_per_token": completion_cost or None,
                _OPENROUTER_DYNAMIC_MARKER: True,
            }
        )

    return generated
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
|
|
@ -300,6 +545,19 @@ class OpenRouterIntegrationService:
|
|||
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
||||
self._health_cache: dict[str, dict[str, Any]] = {}
|
||||
self._enrich_task: asyncio.Task | None = None
|
||||
# Raw OpenRouter pricing per model_id, captured at the same time
|
||||
# we generate configs. Consumed by ``pricing_registration`` to
|
||||
# teach LiteLLM the per-token cost of every dynamic deployment so
|
||||
# the success-callback can populate ``response_cost`` correctly.
|
||||
self._raw_pricing: dict[str, dict[str, str]] = {}
|
||||
# Cached raw catalogue from the most recent fetch. Image / vision
|
||||
# emitters reuse this to avoid a second network call per surface.
|
||||
self._raw_models: list[dict] = []
|
||||
# Image / vision config caches (only populated when the matching
|
||||
# opt-in flag is true on initialize). Refreshed in lockstep with
|
||||
# the chat catalogue.
|
||||
self._image_configs: list[dict] = []
|
||||
self._vision_configs: list[dict] = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
|
|
@ -329,8 +587,32 @@ class OpenRouterIntegrationService:
|
|||
self._initialized = True
|
||||
return []
|
||||
|
||||
self._raw_models = raw_models
|
||||
self._configs = _generate_configs(raw_models, settings)
|
||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
|
||||
# Populate image / vision caches when their opt-in flag is set.
|
||||
# Empty otherwise so the accessors return [] without re-running
|
||||
# filters every refresh.
|
||||
if settings.get("image_generation_enabled"):
|
||||
self._image_configs = _generate_image_gen_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: image-gen emission ON (%d models)",
|
||||
len(self._image_configs),
|
||||
)
|
||||
else:
|
||||
self._image_configs = []
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||
len(self._vision_configs),
|
||||
)
|
||||
else:
|
||||
self._vision_configs = []
|
||||
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
|
|
@ -369,6 +651,8 @@ class OpenRouterIntegrationService:
|
|||
|
||||
new_configs = _generate_configs(raw_models, self._settings)
|
||||
new_by_id = {c["id"]: c for c in new_configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
self._raw_models = raw_models
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
|
|
@ -382,6 +666,29 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
# Image / vision lists are atomic-swapped the same way: filter out
|
||||
# the previous dynamic entries from the live config list and append
|
||||
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||
if self._settings.get("image_generation_enabled"):
|
||||
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||
static_image = [
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||
self._image_configs = new_image
|
||||
|
||||
if self._settings.get("vision_enabled"):
|
||||
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||
static_vision = [
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||
self._vision_configs = new_vision
|
||||
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
|
|
@ -407,6 +714,21 @@ class OpenRouterIntegrationService:
|
|||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
# Re-register LiteLLM pricing for the freshly fetched catalogue
|
||||
# so newly added OR models bill correctly on their first call.
|
||||
# Runs before the router rebuild because the router may issue
|
||||
# cost-table lookups during deployment registration.
|
||||
try:
|
||||
from app.services.pricing_registration import (
|
||||
register_pricing_from_global_configs,
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
|
||||
)
|
||||
|
||||
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||
# out; a refresh also needs to pick up any static-config edits and
|
||||
|
|
@ -635,3 +957,34 @@ class OpenRouterIntegrationService:
|
|||
|
||||
def get_config_by_id(self, config_id: int) -> dict | None:
|
||||
return self._configs_by_id.get(config_id)
|
||||
|
||||
def get_image_generation_configs(self) -> list[dict]:
    """Return a shallow copy of the cached OpenRouter image-generation configs.

    The cache (``self._image_configs``) is filled only while the
    ``image_generation_enabled`` setting is on and is reset to an empty
    list otherwise, so callers get ``[]`` when the feature is off. The
    returned list is a new object; the config dicts inside it are shared
    with the cache.
    """
    return [*self._image_configs]
|
||||
|
||||
def get_vision_llm_configs(self) -> list[dict]:
    """Return a shallow copy of the cached OpenRouter vision-LLM configs.

    The cache (``self._vision_configs``) is filled only while the
    ``vision_enabled`` setting is on and is reset to an empty list
    otherwise. Entries carry ``input_cost_per_token`` /
    ``output_cost_per_token`` so pricing registration can resolve a
    per-token cost for vision calls the same way it does for chat.
    The returned list is a new object; the dicts inside are shared.
    """
    return [*self._vision_configs]
|
||||
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
    """Return a shallow copy of the cached raw OpenRouter pricing map.

    Shape: ``{model_id: {"prompt": str, "completion": str}}``. Values are
    handed back exactly as cached (no float conversion happens here), so
    the caller decides how to treat malformed or unset entries. Only the
    outer mapping is copied; the per-model dicts are shared with the cache.
    """
    return {**self._raw_pricing}
|
||||
|
|
|
|||
274
surfsense_backend/app/services/pricing_registration.py
Normal file
274
surfsense_backend/app/services/pricing_registration.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
"""
|
||||
Pricing registration with LiteLLM.
|
||||
|
||||
Many models reach our LiteLLM callback without LiteLLM knowing their
|
||||
per-token cost — namely:
|
||||
|
||||
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
|
||||
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
|
||||
pricing table).
|
||||
* Static YAML deployments whose ``base_model`` name is operator-defined
|
||||
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
|
||||
not in LiteLLM's table either.
|
||||
|
||||
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
|
||||
and the user gets billed nothing — a fail-safe but wrong answer for a
|
||||
cost-based credit system. This module runs once at startup, after the
|
||||
OpenRouter integration has fetched its catalogue, and registers each
|
||||
known model's pricing with ``litellm.register_model()`` under multiple
|
||||
plausible alias keys (LiteLLM's cost lookup may use any of them
|
||||
depending on whether the call went through the Router, ChatLiteLLM,
|
||||
or a direct ``acompletion``).
|
||||
|
||||
Operators who run a custom Azure deployment whose ``base_model`` name
|
||||
isn't in LiteLLM's table can declare per-token pricing inline in
|
||||
``global_llm_config.yaml`` via ``input_cost_per_token`` and
|
||||
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
|
||||
that declaration the model's calls debit 0 — never overbilled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_float(value: Any) -> float:
|
||||
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
f = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
return f if f > 0 else 0.0
|
||||
|
||||
|
||||
def _alias_set_for_openrouter(model_id: str) -> list[str]:
|
||||
"""Return the alias keys to register an OpenRouter model under.
|
||||
|
||||
LiteLLM's cost-callback lookup key varies by call path:
|
||||
- Router with ``model="openrouter/X"`` → kwargs["model"] is
|
||||
typically ``openrouter/X``.
|
||||
- LiteLLM's own provider routing may strip the prefix and pass the
|
||||
bare ``X`` to the cost-table lookup.
|
||||
Registering under both keeps the lookup hermetic regardless of
|
||||
which path the call took.
|
||||
"""
|
||||
aliases = [f"openrouter/{model_id}", model_id]
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
|
||||
"""Return the alias keys to register a static YAML deployment under.
|
||||
|
||||
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
|
||||
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
|
||||
bare ``model_name`` because callbacks sometimes see whichever was
|
||||
configured first.
|
||||
"""
|
||||
provider_lower = (provider or "").lower()
|
||||
aliases: list[str] = []
|
||||
if base_model:
|
||||
aliases.append(base_model)
|
||||
if provider_lower and base_model:
|
||||
aliases.append(f"{provider_lower}/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(model_name)
|
||||
if provider_lower and model_name and model_name != base_model:
|
||||
aliases.append(f"{provider_lower}/{model_name}")
|
||||
# Azure deployments often surface as "azure/<name>"; normalise the
|
||||
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
|
||||
if provider_lower == "azure_openai":
|
||||
if base_model:
|
||||
aliases.append(f"azure/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(f"azure/{model_name}")
|
||||
return list(dict.fromkeys(a for a in aliases if a))
|
||||
|
||||
|
||||
def _register(
|
||||
aliases: list[str],
|
||||
*,
|
||||
input_cost: float,
|
||||
output_cost: float,
|
||||
provider: str,
|
||||
mode: str = "chat",
|
||||
) -> int:
|
||||
"""Register a single pricing entry under every alias in ``aliases``.
|
||||
|
||||
Returns the count of aliases successfully registered.
|
||||
"""
|
||||
payload: dict[str, dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
payload[alias] = {
|
||||
"input_cost_per_token": input_cost,
|
||||
"output_cost_per_token": output_cost,
|
||||
"litellm_provider": provider,
|
||||
"mode": mode,
|
||||
}
|
||||
if not payload:
|
||||
return 0
|
||||
try:
|
||||
litellm.register_model(payload)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[PricingRegistration] register_model failed for aliases=%s: %s",
|
||||
aliases,
|
||||
exc,
|
||||
)
|
||||
return 0
|
||||
return len(payload)
|
||||
|
||||
|
||||
def _register_chat_shape_configs(
    configs: list[dict],
    *,
    or_pricing: dict[str, dict[str, str]],
    label: str,
) -> tuple[int, int, int, list[str]]:
    """Register per-token pricing for a list of "chat-shape" configs.

    Chat and vision LLM configs both use ``input_cost_per_token`` /
    ``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape,
    so one loop serves both.

    Price source per config:

    * ``OPENROUTER``: the ``or_pricing`` entry for the model name when one
      exists; otherwise the costs carried inline on the config.
    * Any other provider: the costs on the config itself, falling back to
      the same keys under ``litellm_params``.

    Configs whose input and output cost both resolve to ``0.0`` are
    counted as skipped and not registered. A summary is logged at the end.

    Returns:
        ``(registered_models, registered_aliases, skipped, sample_keys)``.
    """
    models_done = 0
    aliases_done = 0
    without_pricing = 0
    samples: list[str] = []

    for cfg in configs:
        provider = str(cfg.get("provider") or "").upper()
        model_name = str(cfg.get("model_name") or "").strip()
        params = cfg.get("litellm_params") or {}
        base_model = str(params.get("base_model") or model_name).strip()
        via_openrouter = provider == "OPENROUTER"

        # Pick the raw price values first; conversion happens once below.
        if via_openrouter:
            cached = or_pricing.get(model_name)
            if cached:
                raw_in = cached.get("prompt")
                raw_out = cached.get("completion")
            else:
                # No cached entry for this model: use inline pricing.
                raw_in = cfg.get("input_cost_per_token")
                raw_out = cfg.get("output_cost_per_token")
        else:
            raw_in = cfg.get("input_cost_per_token") or params.get(
                "input_cost_per_token"
            )
            raw_out = cfg.get("output_cost_per_token") or params.get(
                "output_cost_per_token"
            )

        input_cost = _safe_float(raw_in)
        output_cost = _safe_float(raw_out)
        if input_cost == 0.0 and output_cost == 0.0:
            without_pricing += 1
            continue

        if via_openrouter:
            aliases = _alias_set_for_openrouter(model_name)
            provider_slug = "openrouter"
        else:
            aliases = _alias_set_for_yaml(provider, model_name, base_model)
            provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()

        count = _register(
            aliases,
            input_cost=input_cost,
            output_cost=output_cost,
            provider=provider_slug,
        )
        if count <= 0:
            continue
        models_done += 1
        aliases_done += count
        # Keep only a handful of keys for the summary log line.
        if len(samples) < 6:
            samples.extend(aliases[:2])

    logger.info(
        "[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
        "%d configs had no pricing data; sample registered keys=%s",
        label,
        models_done,
        aliases_done,
        without_pricing,
        samples,
    )
    return models_done, aliases_done, without_pricing, samples
|
||||
|
||||
|
||||
def register_pricing_from_global_configs() -> None:
    """Register pricing for every known LLM deployment with LiteLLM.

    Reads ``config.GLOBAL_LLM_CONFIGS`` and
    ``config.GLOBAL_VISION_LLM_CONFIGS`` and hands each non-empty list to
    ``_register_chat_shape_configs`` (labels ``"chat"`` and ``"vision"``).

    Price sources:

    1. ``OPENROUTER`` configs use the raw pricing cached by
       ``OpenRouterIntegrationService`` when that service is initialised;
       configs missing from the cache fall back to the
       ``input_cost_per_token`` / ``output_cost_per_token`` values set on
       the config itself.
    2. All other configs use operator-declared ``input_cost_per_token`` /
       ``output_cost_per_token`` (top-level or under ``litellm_params``).

    **Image generation is intentionally NOT registered here.** Its cost
    shape is per-image rather than per-token, so it does not fit this
    chat-cost registration path.

    Configs without a resolved cost are skipped, never registered with
    zeros, so pricing LiteLLM may already know natively is not overwritten.
    Failure to obtain the OpenRouter cache is logged at debug level and
    treated as "no cached pricing".
    """
    from app.config import config as app_config

    chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
    vision_configs: list[dict] = list(
        getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
    )
    if not (chat_configs or vision_configs):
        logger.info("[PricingRegistration] no global configs to register")
        return

    or_pricing: dict[str, dict[str, str]] = {}
    try:
        from app.services.openrouter_integration_service import (
            OpenRouterIntegrationService,
        )

        if OpenRouterIntegrationService.is_initialized():
            or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
    except Exception as exc:
        logger.debug(
            "[PricingRegistration] OpenRouter pricing not available yet: %s", exc
        )

    # Chat first, then vision; empty groups are skipped.
    for label, group in (("chat", chat_configs), ("vision", vision_configs)):
        if group:
            _register_chat_shape_configs(group, or_pricing=or_pricing, label=label)
|
||||
106
surfsense_backend/app/services/provider_api_base.py
Normal file
106
surfsense_backend/app/services/provider_api_base.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||
|
||||
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||
provider).
|
||||
|
||||
The chat router has had this defense for a while
|
||||
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||
into a tiny standalone helper so vision and image-gen can share the same
|
||||
source of truth without an inter-service circular import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
    "openrouter": "https://openrouter.ai/api/v1",
    "groq": "https://api.groq.com/openai/v1",
    "mistral": "https://api.mistral.ai/v1",
    "perplexity": "https://api.perplexity.ai",
    "xai": "https://api.x.ai/v1",
    "cerebras": "https://api.cerebras.ai/v1",
    "deepinfra": "https://api.deepinfra.com/v1/openai",
    "fireworks_ai": "https://api.fireworks.ai/inference/v1",
    "together_ai": "https://api.together.xyz/v1",
    "anyscale": "https://api.endpoints.anyscale.com/v1",
    "cometapi": "https://api.cometapi.com/v1",
    "sambanova": "https://api.sambanova.ai/v1",
}
"""Lowercase LiteLLM provider prefix -> default ``api_base``.

Lists only providers that have a well-known, stable public base URL.
Self-hosted / bring-your-own-endpoint providers (ollama, custom, bedrock,
vertex_ai, huggingface, databricks, cloudflare, replicate) are left out on
purpose so their config-driven behaviour is preserved."""
|
||||
|
||||
|
||||
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
    "MOONSHOT": "https://api.moonshot.ai/v1",
    "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
    "MINIMAX": "https://api.minimax.io/v1",
}
"""Uppercase canonical provider key -> base URL.

Consulted when the LiteLLM prefix is the generic ``openai`` shim and only
the config's ``provider`` field identifies the real API: DeepSeek,
Alibaba, Moonshot, Zhipu and MiniMax all share the ``openai`` prefix but
each has its own base URL."""
|
||||
|
||||
|
||||
def resolve_api_base(
|
||||
*,
|
||||
provider: str | None,
|
||||
provider_prefix: str | None,
|
||||
config_api_base: str | None,
|
||||
) -> str | None:
|
||||
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||
|
||||
Cascade (first non-empty wins):
|
||||
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||
provider integration apply its own default (e.g. AzureOpenAI's
|
||||
deployment-derived URL, custom provider's per-deployment URL).
|
||||
|
||||
Args:
|
||||
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||
``"DEEPSEEK"``). Case-insensitive.
|
||||
provider_prefix: The LiteLLM model-string prefix the same call
|
||||
site builds for the model id (e.g. ``"openrouter"``,
|
||||
``"groq"``). Case-insensitive.
|
||||
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||
OpenRouter dynamic config. Empty / whitespace-only means
|
||||
"missing" — the resolver still applies the cascade.
|
||||
|
||||
Returns:
|
||||
A URL string, or ``None`` if no default applies for this provider.
|
||||
"""
|
||||
if config_api_base and config_api_base.strip():
|
||||
return config_api_base
|
||||
|
||||
if provider:
|
||||
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||
if key_default:
|
||||
return key_default
|
||||
|
||||
if provider_prefix:
|
||||
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||
if prefix_default:
|
||||
return prefix_default
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROVIDER_DEFAULT_API_BASE",
|
||||
"PROVIDER_KEY_DEFAULT_API_BASE",
|
||||
"resolve_api_base",
|
||||
]
|
||||
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Capability resolution shared by chat / image / vision call sites.
|
||||
|
||||
Why this exists
|
||||
---------------
|
||||
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
|
||||
single, authoritative answer to one question: *can this chat config accept
|
||||
``image_url`` content blocks?* Without it, the new-chat selector can't badge
|
||||
incompatible models and the streaming task can't fail fast with a friendly
|
||||
error before sending an image to a text-only provider.
|
||||
|
||||
Two functions, two intents:
|
||||
|
||||
- :func:`derive_supports_image_input` — best-effort *True* for catalog and
|
||||
UI surfacing. Default-allow: an unknown / unmapped model is treated as
|
||||
capable so we never lock the user out of a freshly added or
|
||||
third-party-hosted vision model.
|
||||
|
||||
- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
|
||||
task's safety net. Returns True only when LiteLLM's model map *explicitly*
|
||||
sets ``supports_vision=False`` (or its bare-name variant does). Anything
|
||||
else — missing key, lookup exception, ``supports_vision=True`` — returns
|
||||
False so the request flows through to the provider.
|
||||
|
||||
Implementation rule: only public LiteLLM symbols
|
||||
------------------------------------------------
|
||||
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
|
||||
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
|
||||
across releases. The private ``_is_explicitly_disabled_factory`` and
|
||||
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
|
||||
can't silently break us.
|
||||
|
||||
Why the previous round's strict YAML opt-in flag failed
|
||||
-------------------------------------------------------
|
||||
``supports_image_input: false`` was the YAML loader's setdefault. Operators
|
||||
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
|
||||
YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
|
||||
False and the streaming gate rejected every image turn. Sourcing capability
|
||||
from LiteLLM's authoritative model map (which already says
|
||||
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider-name → LiteLLM model-prefix map.
|
||||
#
|
||||
# Owned here because ``app.services.provider_capabilities`` is the
|
||||
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
|
||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||
# map there directly would re-introduce the
|
||||
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
||||
# app.config`` cycle that prompted the move.
|
||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
def _candidate_model_strings(
|
||||
*,
|
||||
provider: str | None,
|
||||
model_name: str | None,
|
||||
base_model: str | None,
|
||||
custom_provider: str | None,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
|
||||
|
||||
LiteLLM's capability lookup is keyed by ``model`` + (optional)
|
||||
``custom_llm_provider``. Different config sources give us different
|
||||
levels of detail, so we try the most-specific keys first and fall back
|
||||
to bare model names so unannotated entries (e.g. an Azure deployment
|
||||
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
|
||||
map. Order matters — the first lookup that returns a definitive answer
|
||||
wins for both helpers.
|
||||
"""
|
||||
candidates: list[tuple[str, str | None]] = []
|
||||
seen: set[tuple[str, str | None]] = set()
|
||||
|
||||
# Record a (model, provider) pair once, in first-seen order; a falsy
# model is ignored so callers need not guard every call.
def _add(model: str | None, llm_provider: str | None) -> None:
|
||||
if not model:
|
||||
return
|
||||
key = (model, llm_provider)
|
||||
if key in seen:
|
||||
return
|
||||
seen.add(key)
|
||||
candidates.append(key)
|
||||
|
||||
# Prefix resolution: the mapped LiteLLM prefix for the provider name,
# or the lowercased provider string when the name is not in the map.
provider_prefix: str | None = None
|
||||
if provider:
|
||||
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
|
||||
if custom_provider:
|
||||
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
|
||||
provider_prefix = custom_provider
|
||||
|
||||
# ``base_model`` takes precedence over ``model_name`` as the primary key.
primary_model = base_model or model_name
|
||||
bare_model = model_name
|
||||
|
||||
# Most-specific first: provider-prefixed identifier with explicit
|
||||
# custom_llm_provider so LiteLLM won't have to guess the provider via
|
||||
# ``get_llm_provider``.
|
||||
if primary_model and provider_prefix:
|
||||
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
|
||||
if "/" in primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
else:
|
||||
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
|
||||
|
||||
# Bare base_model (or model_name) with provider hint — handles entries
|
||||
# the upstream map keys without a provider prefix (most ``gpt-*`` and
|
||||
# ``claude-*`` entries do this).
|
||||
if primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
|
||||
# Fallback to model_name when base_model differs (e.g. an Azure
|
||||
# deployment whose model_name is the deployment id but base_model is the
|
||||
# canonical OpenAI sku).
|
||||
# NOTE(review): indentation is flattened in this rendering, so which of
# the following ``_add`` calls sit under this ``if`` cannot be confirmed
# from here — check the upstream file before relying on the exact list.
if bare_model and bare_model != primary_model:
|
||||
if provider_prefix and "/" not in bare_model:
|
||||
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
|
||||
_add(bare_model, provider_prefix)
|
||||
_add(bare_model, None)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def derive_supports_image_input(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
openrouter_input_modalities: Iterable[str] | None = None,
|
||||
) -> bool:
|
||||
"""Best-effort capability flag for the new-chat selector and catalog.
|
||||
|
||||
Resolution order (first definitive answer wins):
|
||||
|
||||
1. ``openrouter_input_modalities`` (when provided as a non-empty
|
||||
iterable). OpenRouter exposes ``architecture.input_modalities`` per
|
||||
model and that's the authoritative source for OR dynamic configs.
|
||||
2. ``litellm.supports_vision`` against each candidate identifier from
|
||||
:func:`_candidate_model_strings`. Returns True as soon as any
|
||||
candidate confirms vision support.
|
||||
3. Default ``True`` — the conservative-allow stance. An unknown /
|
||||
newly-added / third-party-hosted model is *not* pre-judged. The
|
||||
streaming safety net (:func:`is_known_text_only_chat_model`) is the
|
||||
only place a False ever blocks; everywhere else, a False here would
|
||||
just hide a usable model from the user.
|
||||
|
||||
Returns:
|
||||
True if the model can plausibly accept image input, False only when
|
||||
OpenRouter explicitly says it can't.
|
||||
"""
|
||||
if openrouter_input_modalities is not None:
|
||||
modalities = list(openrouter_input_modalities)
|
||||
if modalities:
|
||||
return "image" in modalities
|
||||
# Empty list explicitly published by OR — treat as "no image".
|
||||
return False
|
||||
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
):
|
||||
try:
|
||||
if litellm.supports_vision(
|
||||
model=model_string, custom_llm_provider=custom_llm_provider
|
||||
):
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"litellm.supports_vision raised for model=%s provider=%s: %s",
|
||||
model_string,
|
||||
custom_llm_provider,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
|
||||
return True
|
||||
|
||||
|
||||
def is_known_text_only_chat_model(
    *,
    provider: str | None = None,
    model_name: str | None = None,
    base_model: str | None = None,
    custom_provider: str | None = None,
) -> bool:
    """Report whether LiteLLM positively marks this model as text-only.

    This is the strict opt-out probe used by the streaming-task safety
    net. It answers True only when LiteLLM's model info carries an
    *explicit* ``supports_vision=False`` for at least one identifier
    produced by :func:`_candidate_model_strings`. A missing key, a
    ``None`` value, ``supports_vision=True`` or a failed lookup all count
    as "not known to be text-only", so the request is let through.

    Why this is separate from ``derive_supports_image_input``
    ---------------------------------------------------------
    The selector is default-allow ("show everything that is plausibly
    capable"), whereas the safety net is default-pass ("block only when
    certain"). Keeping the two intents in separate functions avoids
    mixing those defaults.

    Returns:
        True if some candidate identifier is explicitly flagged
        ``supports_vision=False``; False otherwise.
    """
    candidates = _candidate_model_strings(
        provider=provider,
        model_name=model_name,
        base_model=base_model,
        custom_provider=custom_provider,
    )
    for candidate, llm_provider in candidates:
        try:
            model_info = litellm.get_model_info(
                model=candidate, custom_llm_provider=llm_provider
            )
        except Exception as lookup_error:
            # Unknown model / provider: not evidence of "text only".
            logger.debug(
                "litellm.get_model_info raised for model=%s provider=%s: %s",
                candidate,
                llm_provider,
                lookup_error,
            )
        else:
            # ``ModelInfo`` is a dict at runtime; anything without ``.get``
            # is treated the same as a missing key.
            try:
                vision_flag = model_info.get("supports_vision")  # type: ignore[union-attr]
            except AttributeError:
                continue
            # Identity check on purpose: None / missing must not block.
            if vision_flag is False:
                return True

    return False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"derive_supports_image_input",
|
||||
"is_known_text_only_chat_model",
|
||||
]
|
||||
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""
|
||||
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
|
||||
|
||||
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
|
||||
indexing pipeline (file processors, connector indexers, etl pipeline) can
|
||||
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
|
||||
— without threading ``user_id`` through every parser. The wrapper looks like
|
||||
a chat model from the outside; on the inside it routes each call through
|
||||
``billable_call`` so the user's premium credit pool is reserved → finalized
|
||||
or released, and a ``TokenUsage`` audit row is written.
|
||||
|
||||
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
|
||||
need quota enforcement) so this class only ever wraps premium configs.
|
||||
|
||||
Why a wrapper instead of plumbing ``user_id`` through every caller:
|
||||
|
||||
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
|
||||
Dropbox, local-folder, file-processor, ETL pipeline) each calling
|
||||
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
|
||||
invasive, error-prone, and easy for a future indexer to forget.
|
||||
* Per the design (issue M), we always debit the *search-space owner*, not
|
||||
the triggering user, so ``user_id`` is fully derivable from the search
|
||||
space the caller is already operating on. The wrapper captures it once
|
||||
at construction time.
|
||||
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
|
||||
call run this coroutine"; subclassing isn't safe across versions because
|
||||
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
|
||||
Composition via attribute proxying (``__getattr__``) is robust to
|
||||
upstream changes — every method other than ``ainvoke`` falls through to
|
||||
the inner LLM unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.billable_calls import QuotaInsufficientError, billable_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaCheckedVisionLLM:
|
||||
"""Composition wrapper around a langchain chat model that enforces
|
||||
premium credit quota on every ``ainvoke``.
|
||||
|
||||
Anything other than ``ainvoke`` is forwarded to the inner model so
|
||||
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
|
||||
still work — they simply bypass quota enforcement, which is fine
|
||||
because the indexing pipeline only ever calls ``ainvoke`` today.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_llm: Any,
|
||||
*,
|
||||
user_id: UUID,
|
||||
search_space_id: int,
|
||||
billing_tier: str,
|
||||
base_model: str,
|
||||
quota_reserve_tokens: int | None,
|
||||
usage_type: str = "vision_extraction",
|
||||
) -> None:
|
||||
self._inner = inner_llm
|
||||
self._user_id = user_id
|
||||
self._search_space_id = search_space_id
|
||||
self._billing_tier = billing_tier
|
||||
self._base_model = base_model
|
||||
self._quota_reserve_tokens = quota_reserve_tokens
|
||||
self._usage_type = usage_type
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Proxied async invoke that runs the underlying call inside
|
||||
``billable_call``.
|
||||
|
||||
Raises:
|
||||
QuotaInsufficientError: when the user has exhausted their
|
||||
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
|
||||
catches this and falls back to the document parser.
|
||||
"""
|
||||
async with billable_call(
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
billing_tier=self._billing_tier,
|
||||
base_model=self._base_model,
|
||||
quota_reserve_tokens=self._quota_reserve_tokens,
|
||||
usage_type=self._usage_type,
|
||||
call_details={"model": self._base_model},
|
||||
):
|
||||
return await self._inner.ainvoke(input, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Forward everything else (``invoke``, ``astream``, ``bind``,
|
||||
``with_structured_output``, …) to the inner model.
|
||||
|
||||
``__getattr__`` is only consulted when the attribute is *not*
|
||||
already found on the proxy, which is exactly the contract we
|
||||
want — methods we override stay on the proxy, the rest fall
|
||||
through.
|
||||
"""
|
||||
return getattr(self._inner, name)
|
||||
|
||||
|
||||
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
|
||||
|
|
@ -22,6 +22,71 @@ from app.config import config
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call reservation estimator (USD micro-units)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
|
||||
# request, and so models without registered pricing reserve at least
|
||||
# something while the call runs (debited 0 at finalize anyway when their
|
||||
# cost can't be resolved).
|
||||
_QUOTA_MIN_RESERVE_MICROS = 100
|
||||
|
||||
|
||||
def estimate_call_reserve_micros(
    *,
    base_model: str,
    quota_reserve_tokens: int | None,
) -> int:
    """Estimate how many micro-USD to hold back for a single premium call.

    The estimate is an upper bound derived from LiteLLM's per-token
    pricing for ``base_model``:

        reserve_usd ≈ reserve_tokens x (input_cost + output_cost)

    so expensive models reserve proportionally more than cheap ones. The
    result is clamped to
    ``[_QUOTA_MIN_RESERVE_MICROS, config.QUOTA_MAX_RESERVE_MICROS]`` so a
    single call with extreme declared pricing cannot reserve more than the
    configured ceiling.

    Args:
        base_model: Model identifier looked up in LiteLLM's pricing table.
            An empty value skips the lookup.
        quota_reserve_tokens: Token budget for the call. ``None``, zero or
            a negative value falls back to
            ``config.QUOTA_MAX_RESERVE_PER_CALL``.

    Returns:
        The reservation in micro-USD. When the price lookup fails or both
        per-token prices resolve to zero, the floor
        ``_QUOTA_MIN_RESERVE_MICROS`` is returned.
    """
    token_budget = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    if token_budget <= 0:
        token_budget = config.QUOTA_MAX_RESERVE_PER_CALL

    try:
        from litellm import get_model_info

        pricing = get_model_info(base_model) if base_model else {}
        price_in = float(pricing.get("input_cost_per_token") or 0.0)
        price_out = float(pricing.get("output_cost_per_token") or 0.0)
    except Exception as exc:
        logger.debug(
            "[quota_reserve] cost lookup failed for base_model=%s: %s",
            base_model,
            exc,
        )
        # Reset both: one price may have parsed before the other failed.
        price_in = price_out = 0.0

    # No usable pricing at all -> reserve only the floor.
    if not (price_in or price_out):
        return _QUOTA_MIN_RESERVE_MICROS

    micros = round(token_budget * (price_in + price_out) * 1_000_000)
    # Floor first, ceiling last, so the ceiling wins if the two ever cross.
    micros = max(micros, _QUOTA_MIN_RESERVE_MICROS)
    return min(micros, config.QUOTA_MAX_RESERVE_MICROS)
|
||||
|
||||
|
||||
class QuotaScope(StrEnum):
|
||||
ANONYMOUS = "anonymous"
|
||||
PREMIUM = "premium"
|
||||
|
|
@ -444,8 +509,16 @@ class TokenQuotaService:
|
|||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
reserve_tokens: int,
|
||||
reserve_micros: int,
|
||||
) -> QuotaResult:
|
||||
"""Reserve ``reserve_micros`` (USD micro-units) from the user's
|
||||
premium credit balance.
|
||||
|
||||
``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
|
||||
all in micro-USD on this code path; callers (chat stream,
|
||||
token-status route, FE display) convert to dollars by dividing
|
||||
by 1_000_000.
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -465,11 +538,11 @@ class TokenQuotaService:
|
|||
limit=0,
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
|
||||
effective = used + reserved + reserve_tokens
|
||||
effective = used + reserved + reserve_micros
|
||||
if effective > limit:
|
||||
remaining = max(0, limit - used - reserved)
|
||||
await db_session.rollback()
|
||||
|
|
@ -482,10 +555,10 @@ class TokenQuotaService:
|
|||
remaining=remaining,
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = reserved + reserve_tokens
|
||||
user.premium_credit_micros_reserved = reserved + reserve_micros
|
||||
await db_session.commit()
|
||||
|
||||
new_reserved = reserved + reserve_tokens
|
||||
new_reserved = reserved + reserve_micros
|
||||
remaining = max(0, limit - used - new_reserved)
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
||||
|
|
@ -510,9 +583,12 @@ class TokenQuotaService:
|
|||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
request_id: str,
|
||||
actual_tokens: int,
|
||||
reserved_tokens: int,
|
||||
actual_micros: int,
|
||||
reserved_micros: int,
|
||||
) -> QuotaResult:
|
||||
"""Settle the reservation: release ``reserved_micros`` and debit
|
||||
``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -529,16 +605,18 @@ class TokenQuotaService:
|
|||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
user.premium_credit_micros_reserved = max(
|
||||
0, user.premium_credit_micros_reserved - reserved_micros
|
||||
)
|
||||
user.premium_credit_micros_used = (
|
||||
user.premium_credit_micros_used + actual_micros
|
||||
)
|
||||
user.premium_tokens_used = user.premium_tokens_used + actual_tokens
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
|
@ -562,8 +640,13 @@ class TokenQuotaService:
|
|||
async def premium_release(
|
||||
db_session: AsyncSession,
|
||||
user_id: Any,
|
||||
reserved_tokens: int,
|
||||
reserved_micros: int,
|
||||
) -> None:
|
||||
"""Release ``reserved_micros`` previously held by ``premium_reserve``.
|
||||
|
||||
Used when a request fails before finalize (so the reservation
|
||||
doesn't leak credit).
|
||||
"""
|
||||
from app.db import User
|
||||
|
||||
user = (
|
||||
|
|
@ -576,8 +659,8 @@ class TokenQuotaService:
|
|||
.scalar_one_or_none()
|
||||
)
|
||||
if user is not None:
|
||||
user.premium_tokens_reserved = max(
|
||||
0, user.premium_tokens_reserved - reserved_tokens
|
||||
user.premium_credit_micros_reserved = max(
|
||||
0, user.premium_credit_micros_reserved - reserved_micros
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
|
|
@ -598,9 +681,9 @@ class TokenQuotaService:
|
|||
allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
|
||||
)
|
||||
|
||||
limit = user.premium_tokens_limit
|
||||
used = user.premium_tokens_used
|
||||
reserved = user.premium_tokens_reserved
|
||||
limit = user.premium_credit_micros_limit
|
||||
used = user.premium_credit_micros_used
|
||||
reserved = user.premium_credit_micros_reserved
|
||||
remaining = max(0, limit - used - reserved)
|
||||
|
||||
warning_threshold = int(limit * 0.8)
|
||||
|
|
|
|||
|
|
@ -16,11 +16,14 @@ from __future__ import annotations
|
|||
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -35,6 +38,8 @@ class TokenCallRecord:
|
|||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cost_micros: int = 0
|
||||
call_kind: str = "chat"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -49,6 +54,8 @@ class TurnTokenAccumulator:
|
|||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
cost_micros: int = 0,
|
||||
call_kind: str = "chat",
|
||||
) -> None:
|
||||
self.calls.append(
|
||||
TokenCallRecord(
|
||||
|
|
@ -56,20 +63,28 @@ class TurnTokenAccumulator:
|
|||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
call_kind=call_kind,
|
||||
)
|
||||
)
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
"""Return token counts grouped by model name."""
|
||||
"""Return token counts (and cost) grouped by model name."""
|
||||
by_model: dict[str, dict[str, int]] = {}
|
||||
for c in self.calls:
|
||||
entry = by_model.setdefault(
|
||||
c.model,
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
{
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost_micros": 0,
|
||||
},
|
||||
)
|
||||
entry["prompt_tokens"] += c.prompt_tokens
|
||||
entry["completion_tokens"] += c.completion_tokens
|
||||
entry["total_tokens"] += c.total_tokens
|
||||
entry["cost_micros"] += c.cost_micros
|
||||
return by_model
|
||||
|
||||
@property
|
||||
|
|
@ -84,6 +99,21 @@ class TurnTokenAccumulator:
|
|||
def total_completion_tokens(self) -> int:
|
||||
return sum(c.completion_tokens for c in self.calls)
|
||||
|
||||
@property
def total_cost_micros(self) -> int:
    """Total ``cost_micros`` over every call recorded in this turn.

    ``stream_new_chat`` uses this figure to debit the real provider
    cost (in micro-USD) of a premium turn from the user's premium
    credit balance. Each call's ``cost_micros`` is filled in by
    ``TokenTrackingCallback.async_log_success_event`` from
    ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
    with several fallback paths so OpenRouter dynamic models and
    custom Azure deployments still bill correctly when
    ``pricing_registration`` ran at startup.
    """
    running_total = 0
    for record in self.calls:
        running_total += record.cost_micros
    return running_total
|
||||
|
||||
def serialized_calls(self) -> list[dict[str, Any]]:
|
||||
return [dataclasses.asdict(c) for c in self.calls]
|
||||
|
||||
|
|
@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
|
|||
|
||||
|
||||
def start_turn() -> TurnTokenAccumulator:
|
||||
"""Create a fresh accumulator for the current async context and return it."""
|
||||
"""Create a fresh accumulator for the current async context and return it.
|
||||
|
||||
NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
|
||||
short-lived per-call billable wrappers (image generation REST endpoint,
|
||||
vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
|
||||
ContextVar reset token to restore the *previous* accumulator on exit and
|
||||
avoids leaking call records across reservations (issue B).
|
||||
"""
|
||||
acc = TurnTokenAccumulator()
|
||||
_turn_accumulator.set(acc)
|
||||
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
||||
|
|
@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
|
|||
return _turn_accumulator.get()
|
||||
|
||||
|
||||
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
    """Install a fresh ``TurnTokenAccumulator`` for one ``async with`` block.

    On entry a new accumulator is bound to the turn ContextVar; on exit
    (normal or exceptional) the ContextVar is *reset* to whatever it held
    before, using the token returned by ``ContextVar.set``.

    This is the primitive to use for per-call billable operations
    (image generation, vision LLM extraction, podcasts) that may run
    inside an outer chat turn, or back to back on the same background
    worker. A bare ``ContextVar.set`` without a matching ``reset`` — which
    is what :func:`start_turn` does — would leave the inner accumulator
    visible to the outer scope, so the outer chat turn would debit the
    cost a second time.

    Usage::

        async with scoped_turn() as acc:
            await llm.ainvoke(...)
            # acc.total_cost_micros captures cost from the LiteLLM callback
        # Outer accumulator (if any) is restored here.

    Yields:
        The accumulator that is active for the duration of the block.
    """
    scoped_acc = TurnTokenAccumulator()
    reset_token = _turn_accumulator.set(scoped_acc)
    logger.debug(
        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
        id(scoped_acc),
        reset_token,
    )
    try:
        yield scoped_acc
    finally:
        # Restore the previous accumulator before logging what was captured.
        _turn_accumulator.reset(reset_token)
        logger.debug(
            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
            id(scoped_acc),
            len(scoped_acc.calls),
            scoped_acc.total_cost_micros,
        )
|
||||
|
||||
|
||||
def _extract_cost_usd(
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: Any,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
is_image: bool = False,
|
||||
) -> float:
|
||||
"""Best-effort USD cost extraction for a single LLM/image call.
|
||||
|
||||
Tries four sources in priority order and returns the first that
|
||||
yields a positive number; returns 0.0 if all four fail (the call
|
||||
will then debit nothing from the user's balance — fail-safe).
|
||||
|
||||
Sources:
|
||||
1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
|
||||
field, populated for ``Router.acompletion`` since PR #12500.
|
||||
2. ``response_obj._hidden_params["response_cost"]`` — same value
|
||||
exposed on the response itself.
|
||||
3. ``litellm.completion_cost(completion_response=response_obj)``
|
||||
— recompute from the response and LiteLLM's pricing table.
|
||||
4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
|
||||
— manual fallback for OpenRouter/custom-Azure models that
|
||||
only resolve via aliases registered by
|
||||
``pricing_registration`` at startup. **Skipped for image
|
||||
responses** — ``cost_per_token`` does not support ``ImageResponse``
|
||||
and would raise; the cost map for image-gen lives in different
|
||||
keys (``output_cost_per_image``) handled by ``completion_cost``.
|
||||
"""
|
||||
cost = kwargs.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
hidden = getattr(response_obj, "_hidden_params", None) or {}
|
||||
if isinstance(hidden, dict):
|
||||
cost = hidden.get("response_cost")
|
||||
if cost is not None:
|
||||
try:
|
||||
value = float(cost)
|
||||
except (TypeError, ValueError):
|
||||
value = 0.0
|
||||
if value > 0:
|
||||
return value
|
||||
|
||||
try:
|
||||
value = float(litellm.completion_cost(completion_response=response_obj))
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
if is_image:
|
||||
# Image-gen path: OpenRouter's image responses can omit
|
||||
# ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
|
||||
# then *raises* (no cost map for OpenRouter image models).
|
||||
# Bail out with a warning rather than falling through to
|
||||
# cost_per_token (which is also incompatible with ImageResponse).
|
||||
logger.warning(
|
||||
"[TokenTracking] completion_cost failed for image model=%s "
|
||||
"(provider may have omitted usage.cost). Debiting 0. "
|
||||
"Cause: %s",
|
||||
model,
|
||||
exc,
|
||||
)
|
||||
return 0.0
|
||||
logger.debug(
|
||||
"[TokenTracking] completion_cost failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
if is_image:
|
||||
# Never call cost_per_token for ImageResponse — keys mismatch and
|
||||
# the function is documented chat-only.
|
||||
return 0.0
|
||||
|
||||
if model and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
try:
|
||||
prompt_cost, completion_cost = litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
value = float(prompt_cost) + float(completion_cost)
|
||||
if value > 0:
|
||||
return value
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
|
||||
)
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
class TokenTrackingCallback(CustomLogger):
|
||||
"""LiteLLM callback that captures token usage into the turn accumulator."""
|
||||
|
||||
|
|
@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
|
|||
)
|
||||
return
|
||||
|
||||
# Detect image generation responses — they have a different usage
|
||||
# shape (ImageUsage with input_tokens/output_tokens) and require a
|
||||
# different cost-extraction path. We probe by class name to avoid a
|
||||
# hard import dependency on litellm internals.
|
||||
response_cls = type(response_obj).__name__
|
||||
is_image = response_cls == "ImageResponse"
|
||||
|
||||
usage = getattr(response_obj, "usage", None)
|
||||
if not usage:
|
||||
logger.debug(
|
||||
|
|
@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
|
|||
)
|
||||
return
|
||||
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||
if is_image:
|
||||
# ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
|
||||
# (not prompt_tokens/completion_tokens). Several providers
|
||||
# populate only one or neither (e.g. OpenRouter's gpt-image-1
|
||||
# passes through `input_tokens` from the prompt but no
|
||||
# completion); fall through gracefully to 0.
|
||||
prompt_tokens = getattr(usage, "input_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "output_tokens", 0) or 0
|
||||
total_tokens = (
|
||||
getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
|
||||
)
|
||||
call_kind = "image_generation"
|
||||
else:
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||
call_kind = "chat"
|
||||
|
||||
model = kwargs.get("model", "unknown")
|
||||
|
||||
cost_usd = _extract_cost_usd(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
is_image=is_image,
|
||||
)
|
||||
cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0
|
||||
|
||||
if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
logger.warning(
|
||||
"[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
|
||||
"kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
|
||||
"input_cost_per_token/output_cost_per_token (or rely on response_cost "
|
||||
"for image generation).",
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
call_kind,
|
||||
)
|
||||
|
||||
acc.add(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
call_kind=call_kind,
|
||||
)
|
||||
logger.info(
|
||||
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
||||
"[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
|
||||
"cost=$%.6f (%d micros) (accumulator now has %d calls)",
|
||||
model,
|
||||
call_kind,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cost_usd,
|
||||
cost_micros,
|
||||
len(acc.calls),
|
||||
)
|
||||
|
||||
|
|
@ -168,6 +388,7 @@ async def record_token_usage(
|
|||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
cost_micros: int = 0,
|
||||
model_breakdown: dict[str, Any] | None = None,
|
||||
call_details: dict[str, Any] | None = None,
|
||||
thread_id: int | None = None,
|
||||
|
|
@ -185,6 +406,7 @@ async def record_token_usage(
|
|||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
model_breakdown=model_breakdown,
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
|
|
@ -194,11 +416,12 @@ async def record_token_usage(
|
|||
)
|
||||
session.add(record)
|
||||
logger.debug(
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
|
||||
usage_type,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cost_micros,
|
||||
)
|
||||
return record
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
|||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
|
@ -108,10 +110,11 @@ class VisionLLMRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
|
|
@ -120,8 +123,13 @@ class VisionLLMRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,25 @@
|
|||
"""Celery tasks package."""
|
||||
"""Celery tasks package.
|
||||
|
||||
Also hosts the small helpers every async celery task should use to
|
||||
spin up its event loop. See :func:`run_async_celery_task` for the
|
||||
canonical pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_celery_engine = None
|
||||
_celery_session_maker = None
|
||||
|
||||
|
|
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
|
|||
_celery_engine, expire_on_commit=False
|
||||
)
|
||||
return _celery_session_maker
|
||||
|
||||
|
||||
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
    """Synchronously empty the connection pool of the shared ``app.db.engine``.

    The shared engine (used by ``shielded_async_session`` and most
    routes / services) is a module-level singleton backed by a real pool.
    Every celery task runs on its own fresh ``asyncio`` event loop, and
    asyncpg connections hold on to the loop that opened them. If a later
    task's loop checks out such a stale connection, SQLAlchemy's
    ``pool_pre_ping`` either crashes with::

        AttributeError: 'NoneType' object has no attribute 'send'
          File ".../asyncio/proactor_events.py", line 402, in _loop_writing
            self._write_fut = self._loop._proactor.send(self._sock, data)

    or hangs forever in asyncpg's ``Connection._cancel`` cleanup
    coroutine, which can never run because its loop is gone.

    Disposing the engine makes the pool drop every cached connection, so
    the next checkout opens a new one on the current loop.

    Args:
        loop: Event loop on which the engine's ``dispose()`` coroutine is
            driven to completion.

    Any failure (including a failed import of ``app.db``) is logged as a
    warning and never propagated, so this is safe inside ``finally``.
    """
    try:
        from app.db import engine as pooled_engine

        pending_dispose = pooled_engine.dispose()
        loop.run_until_complete(pending_dispose)
    except Exception:
        logger.warning("Shared DB engine dispose() failed", exc_info=True)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
|
||||
"""Run an async coroutine inside a fresh event loop with proper
|
||||
DB-engine cleanup.
|
||||
|
||||
This is the canonical entry point for every async celery task.
|
||||
It performs three responsibilities that were previously copy-pasted
|
||||
(incorrectly) across each task module:
|
||||
|
||||
1. Create a fresh ``asyncio`` loop and install it on the current
|
||||
thread (celery's ``--pool=solo`` runs every task on the main
|
||||
thread, but other pool types don't).
|
||||
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
|
||||
any stale connections left over from a previous task's loop
|
||||
are dropped — defends against tasks that crashed without
|
||||
cleaning up.
|
||||
3. Dispose the shared engine AFTER the task runs so the
|
||||
connections we opened on this loop are released before the
|
||||
loop closes (avoids ``coroutine 'Connection._cancel' was
|
||||
never awaited`` warnings and the next-task hang).
|
||||
|
||||
Use as::
|
||||
|
||||
@celery_app.task(name="my_task", bind=True)
|
||||
def my_task(self, *args):
|
||||
return run_async_celery_task(lambda: _my_task_impl(*args))
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# Defense-in-depth: prior task may have crashed before
|
||||
# disposing. Idempotent — no-op if pool is already empty.
|
||||
_dispose_shared_db_engine(loop)
|
||||
return loop.run_until_complete(coro_factory())
|
||||
finally:
|
||||
# Drop any connections this task opened so they don't leak
|
||||
# into the next task's loop.
|
||||
_dispose_shared_db_engine(loop)
|
||||
with contextlib.suppress(Exception):
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
with contextlib.suppress(Exception):
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_celery_session_maker",
|
||||
"run_async_celery_task",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import traceback
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -49,22 +49,15 @@ def index_notion_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Notion pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_notion_pages(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_notion_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_notion_pages(
|
||||
|
|
@ -95,19 +88,11 @@ def index_github_repos_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index GitHub repositories."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_github_repos(
|
||||
|
|
@ -138,19 +123,11 @@ def index_confluence_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Confluence pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_confluence_pages(
|
||||
|
|
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Calendar events."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_calendar_events(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_calendar_events(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_google_calendar_events(
|
||||
|
|
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Gmail messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_gmail_messages(
|
||||
|
|
@ -269,22 +231,14 @@ def index_google_drive_files_task(
|
|||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||
):
|
||||
"""Celery task to index Google Drive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_drive_files(
|
||||
|
|
@ -317,22 +271,14 @@ def index_onedrive_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index OneDrive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_onedrive_files(
|
||||
|
|
@ -365,22 +311,14 @@ def index_dropbox_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index Dropbox folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_dropbox_files(
|
||||
|
|
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Elasticsearch documents."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_elasticsearch_documents(
|
||||
|
|
@ -457,22 +387,15 @@ def index_crawled_urls_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Web page Urls."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_crawled_urls(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_crawled_urls(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_crawled_urls(
|
||||
|
|
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index BookStack pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_bookstack_pages(
|
||||
|
|
@ -546,19 +461,11 @@ def index_composio_connector_task(
|
|||
end_date: str | None,
|
||||
):
|
||||
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_composio_connector(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from app.db import Document
|
|||
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
|
|||
document_id: ID of document to reindex
|
||||
user_id: ID of user who edited the document
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reindex_document(document_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
|
||||
|
||||
|
||||
async def _reindex_document(document_id: int, user_id: str):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from app.celery_app import celery_app
|
|||
from app.config import config
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
from app.tasks.connector_indexers.local_folder_indexer import (
|
||||
index_local_folder,
|
||||
index_uploaded_files,
|
||||
|
|
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
|
|||
)
|
||||
def delete_document_task(self, document_id: int):
    """Celery task to delete a document and its chunks in batches.

    Runs ``_delete_document_background`` through
    ``run_async_celery_task`` and returns that helper's result.

    Args:
        self: Task instance; unused here. Presumably supplied by a
            ``bind=True`` decorator above this definition — confirm,
            the decorator is not fully visible.
        document_id: ID of the document to delete.
    """
    # Single execution path: the coroutine is built lazily by the lambda
    # and run once by the shared helper (no separate hand-rolled loop).
    return run_async_celery_task(lambda: _delete_document_background(document_id))
|
||||
|
||||
|
||||
async def _delete_document_background(document_id: int) -> None:
|
||||
|
|
@ -153,14 +148,9 @@ def delete_folder_documents_task(
|
|||
folder_subtree_ids: list[int] | None = None,
|
||||
):
|
||||
"""Celery task to delete documents first, then the folder rows."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
|
||||
|
||||
async def _delete_folder_documents(
|
||||
|
|
@ -209,12 +199,9 @@ async def _delete_folder_documents(
|
|||
)
|
||||
def delete_search_space_task(self, search_space_id: int):
    """Celery task to delete a search space and heavy child rows in batches.

    Runs ``_delete_search_space_background`` through
    ``run_async_celery_task`` and returns that helper's result.

    Args:
        self: Task instance; unused here. Presumably supplied by a
            ``bind=True`` decorator above this definition — confirm,
            the decorator is not fully visible.
        search_space_id: ID of the search space to delete.
    """
    # Single execution path: the coroutine is built lazily by the lambda
    # and run once by the shared helper (no separate hand-rolled loop).
    return run_async_celery_task(
        lambda: _delete_search_space_background(search_space_id)
    )
|
||||
|
||||
|
||||
async def _delete_search_space_background(search_space_id: int) -> None:
|
||||
|
|
@ -269,18 +256,11 @@ def process_extension_document_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
# Create a new event loop for this task
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_extension_document(
|
||||
|
|
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _process_youtube_video(url, search_space_id, user_id)
|
||||
)
|
||||
|
||||
|
||||
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
||||
|
|
@ -573,12 +549,9 @@ def process_file_upload_task(
|
|||
except Exception as e:
|
||||
logger.warning(f"[process_file_upload] Could not get file size: {e}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
)
|
||||
logger.info(
|
||||
f"[process_file_upload] Task completed successfully for: {filename}"
|
||||
|
|
@ -589,8 +562,6 @@ def process_file_upload_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _process_file_upload(
|
||||
|
|
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
|
|||
"File may have been removed before syncing could start."
|
||||
)
|
||||
# Mark document as failed since file is missing
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
run_async_celery_task(
|
||||
lambda: _mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
return
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_with_document(
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_with_document(
|
||||
document_id,
|
||||
temp_path,
|
||||
filename,
|
||||
|
|
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_document_failed(document_id: int, reason: str):
|
||||
|
|
@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
|
|||
search_space_id: ID of the search space
|
||||
connector_id: ID of the Circleback connector (for deletion support)
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_circleback_meeting(
|
||||
|
|
@ -1291,25 +1246,19 @@ def index_local_folder_task(
|
|||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_local_folder_async(
|
||||
|
|
@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
|
|||
processing_mode: str = "basic",
|
||||
):
|
||||
"""Celery task to index files uploaded from the desktop app."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_uploaded_folder_files_async(
|
||||
|
|
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
|
|||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
    """Full AI sort for all documents in a search space.

    Runs ``_ai_sort_search_space_async`` through
    ``run_async_celery_task`` and returns that helper's result.

    Args:
        self: Bound task instance (``bind=True``); unused here.
        search_space_id: ID of the search space to sort.
        user_id: ID of the user, forwarded unchanged to the async
            implementation.
    """
    # Single execution path: the coroutine is built lazily by the lambda
    # and run once by the shared helper (no separate hand-rolled loop).
    return run_async_celery_task(
        lambda: _ai_sort_search_space_async(search_space_id, user_id)
    )
|
||||
|
||||
|
||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||
|
|
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
|||
)
|
||||
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
    """Incremental AI sort for a single document after indexing.

    Runs ``_ai_sort_document_async`` through ``run_async_celery_task``
    and returns that helper's result.

    Args:
        self: Task instance; unused here. Presumably supplied by a
            ``bind=True`` decorator above this definition — confirm,
            the decorator is not fully visible.
        search_space_id: ID of the search space containing the document.
        user_id: ID of the user, forwarded unchanged to the async
            implementation.
        document_id: ID of the document to sort.
    """
    # Single execution path: the coroutine is built lazily by the lambda
    # and run once by the shared helper (no separate hand-rolled loop).
    return run_async_celery_task(
        lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
    )
|
||||
|
||||
|
||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue