Merge pull request #966 from MODSetter/dev

feat: HITL Workflows and Fixing Real-Time Sync
This commit is contained in:
Rohan Verma 2026-03-25 00:10:15 -07:00 committed by GitHub
commit 8227d1852f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
320 changed files with 33857 additions and 19630 deletions

View file

@ -57,7 +57,7 @@ jobs:
working-directory: surfsense_web working-directory: surfsense_web
env: env:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }} NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }}
NEXT_PUBLIC_ELECTRIC_URL: ${{ vars.NEXT_PUBLIC_ELECTRIC_URL }} NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }} NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }} NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}

View file

@ -164,8 +164,7 @@ jobs:
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }} ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }} ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }} ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__' || '' }} ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__' || '' }}
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }} ${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
- name: Export digest - name: Export digest

View file

@ -35,7 +35,7 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# BACKEND_PORT=8929 # BACKEND_PORT=8929
# FRONTEND_PORT=3929 # FRONTEND_PORT=3929
# ELECTRIC_PORT=5929 # ZERO_CACHE_PORT=5929
# SEARXNG_PORT=8888 # SEARXNG_PORT=8888
# FLOWER_PORT=5555 # FLOWER_PORT=5555
@ -58,7 +58,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL # NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
# NEXT_PUBLIC_ETL_SERVICE=DOCLING # NEXT_PUBLIC_ETL_SERVICE=DOCLING
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted # NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
# NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Custom Domain / Reverse Proxy # Custom Domain / Reverse Proxy
@ -71,8 +70,35 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# NEXT_FRONTEND_URL=https://app.yourdomain.com # NEXT_FRONTEND_URL=https://app.yourdomain.com
# BACKEND_URL=https://api.yourdomain.com # BACKEND_URL=https://api.yourdomain.com
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com # NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
# NEXT_PUBLIC_ELECTRIC_URL=https://electric.yourdomain.com # NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com
# ------------------------------------------------------------------------------
# Zero-cache (real-time sync)
# ------------------------------------------------------------------------------
# Defaults work out of the box for Docker deployments.
# Change ZERO_ADMIN_PASSWORD for security in production.
# ZERO_ADMIN_PASSWORD=surfsense-zero-admin
# Full override for the Zero → Postgres connection URLs.
# Leave commented out to use the Docker-managed `db` container (default).
# ZERO_UPSTREAM_DB=postgresql://surfsense:surfsense@db:5432/surfsense
# ZERO_CVR_DB=postgresql://surfsense:surfsense@db:5432/surfsense
# ZERO_CHANGE_DB=postgresql://surfsense:surfsense@db:5432/surfsense
# ZERO_QUERY_URL: where zero-cache forwards query requests for resolution.
# ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
# skip its own JWT verification and let the app endpoints handle auth instead.
# The mutate endpoint is a no-op that returns an empty response.
# Default: Docker service networking (http://frontend:3000/api/zero/...).
# Override when running the frontend outside Docker:
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
# Override for custom domain:
# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query
# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Database (defaults work out of the box, change for security) # Database (defaults work out of the box, change for security)
@ -101,19 +127,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# Supports TLS: rediss://:password@host:6380/0 # Supports TLS: rediss://:password@host:6380/0
# REDIS_URL=redis://redis:6379/0 # REDIS_URL=redis://redis:6379/0
# ------------------------------------------------------------------------------
# Electric SQL (real-time sync credentials)
# ------------------------------------------------------------------------------
# These must match on the db, backend, and electric services.
# Change for security; defaults work out of the box.
# ELECTRIC_DB_USER=electric
# ELECTRIC_DB_PASSWORD=electric_password
# Full override for the Electric → Postgres connection URL.
# Leave commented out to use the Docker-managed `db` container (default).
# Uncomment and set `db` to `host.docker.internal` when pointing Electric at a local Postgres instance (e.g. Postgres.app on macOS):
# ELECTRIC_DATABASE_URL=postgresql://electric:electric_password@db:5432/surfsense?sslmode=disable
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# TTS & STT (Text-to-Speech / Speech-to-Text) # TTS & STT (Text-to-Speech / Speech-to-Text)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------

View file

@ -18,13 +18,10 @@ services:
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro - ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
environment: environment:
- POSTGRES_USER=${DB_USER:-postgres} - POSTGRES_USER=${DB_USER:-postgres}
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres} - POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
- POSTGRES_DB=${DB_NAME:-surfsense} - POSTGRES_DB=${DB_NAME:-surfsense}
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
command: postgres -c config_file=/etc/postgresql/postgresql.conf command: postgres -c config_file=/etc/postgresql/postgresql.conf
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"] test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
@ -91,8 +88,6 @@ services:
- UNSTRUCTURED_HAS_PATCHED_LOOP=1 - UNSTRUCTURED_HAS_PATCHED_LOOP=1
- LANGCHAIN_TRACING_V2=false - LANGCHAIN_TRACING_V2=false
- LANGSMITH_TRACING=false - LANGSMITH_TRACING=false
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- AUTH_TYPE=${AUTH_TYPE:-LOCAL} - AUTH_TYPE=${AUTH_TYPE:-LOCAL}
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000} - NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
- SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080} - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
@ -130,8 +125,6 @@ services:
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0} - REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
- CELERY_TASK_DEFAULT_QUEUE=surfsense - CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app - PYTHONPATH=/app
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080} - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
- SERVICE_ROLE=worker - SERVICE_ROLE=worker
depends_on: depends_on:
@ -176,20 +169,28 @@ services:
# - redis # - redis
# - celery_worker # - celery_worker
electric: zero-cache:
image: electricsql/electric:1.4.10 image: rocicorp/zero:0.26.2
ports: ports:
- "${ELECTRIC_PORT:-5133}:3000" - "${ZERO_CACHE_PORT:-4848}:4848"
extra_hosts:
- "host.docker.internal:host-gateway"
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
environment: environment:
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}} - ZERO_UPSTREAM_DB=${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
- ELECTRIC_INSECURE=true - ZERO_CVR_DB=${ZERO_CVR_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
- ELECTRIC_WRITE_TO_PG_MODE=direct - ZERO_CHANGE_DB=${ZERO_CHANGE_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
- ZERO_REPLICA_FILE=/data/zero.db
- ZERO_ADMIN_PASSWORD=${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
volumes:
- zero_cache_data:/data
restart: unless-stopped restart: unless-stopped
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"] test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@ -201,8 +202,7 @@ services:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000} NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL} NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING} NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133} NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}}
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted} NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
ports: ports:
- "${FRONTEND_PORT:-3000}:3000" - "${FRONTEND_PORT:-3000}:3000"
@ -211,7 +211,7 @@ services:
depends_on: depends_on:
backend: backend:
condition: service_healthy condition: service_healthy
electric: zero-cache:
condition: service_healthy condition: service_healthy
volumes: volumes:
@ -223,3 +223,5 @@ volumes:
name: surfsense-dev-redis name: surfsense-dev-redis
shared_temp: shared_temp:
name: surfsense-dev-shared-temp name: surfsense-dev-shared-temp
zero_cache_data:
name: surfsense-dev-zero-cache

View file

@ -15,13 +15,10 @@ services:
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro - ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
environment: environment:
POSTGRES_USER: ${DB_USER:-surfsense} POSTGRES_USER: ${DB_USER:-surfsense}
POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense} POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense}
POSTGRES_DB: ${DB_NAME:-surfsense} POSTGRES_DB: ${DB_NAME:-surfsense}
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
command: postgres -c config_file=/etc/postgresql/postgresql.conf command: postgres -c config_file=/etc/postgresql/postgresql.conf
restart: unless-stopped restart: unless-stopped
healthcheck: healthcheck:
@ -72,8 +69,6 @@ services:
PYTHONPATH: /app PYTHONPATH: /app
UVICORN_LOOP: asyncio UVICORN_LOOP: asyncio
UNSTRUCTURED_HAS_PATCHED_LOOP: "1" UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}} NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}}
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
# Daytona Sandbox uncomment and set credentials to enable cloud code execution # Daytona Sandbox uncomment and set credentials to enable cloud code execution
@ -112,8 +107,6 @@ services:
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0} REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
CELERY_TASK_DEFAULT_QUEUE: surfsense CELERY_TASK_DEFAULT_QUEUE: surfsense
PYTHONPATH: /app PYTHONPATH: /app
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
SERVICE_ROLE: worker SERVICE_ROLE: worker
depends_on: depends_on:
@ -165,20 +158,28 @@ services:
# - celery_worker # - celery_worker
# restart: unless-stopped # restart: unless-stopped
electric: zero-cache:
image: electricsql/electric:1.4.10 image: rocicorp/zero:0.26.2
ports: ports:
- "${ELECTRIC_PORT:-5929}:3000" - "${ZERO_CACHE_PORT:-5929}:4848"
extra_hosts:
- "host.docker.internal:host-gateway"
environment: environment:
DATABASE_URL: ${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}} ZERO_UPSTREAM_DB: ${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
ELECTRIC_INSECURE: "true" ZERO_CVR_DB: ${ZERO_CVR_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
ELECTRIC_WRITE_TO_PG_MODE: direct ZERO_CHANGE_DB: ${ZERO_CHANGE_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
ZERO_REPLICA_FILE: /data/zero.db
ZERO_ADMIN_PASSWORD: ${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
ZERO_QUERY_URL: ${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
ZERO_MUTATE_URL: ${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
volumes:
- zero_cache_data:/data
restart: unless-stopped restart: unless-stopped
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"] test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@ -189,17 +190,16 @@ services:
- "${FRONTEND_PORT:-3929}:3000" - "${FRONTEND_PORT:-3929}:3000"
environment: environment:
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}} NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}}
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:${ELECTRIC_PORT:-5929}} NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}}
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL} NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING} NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
labels: labels:
- "com.centurylinklabs.watchtower.enable=true" - "com.centurylinklabs.watchtower.enable=true"
depends_on: depends_on:
backend: backend:
condition: service_healthy condition: service_healthy
electric: zero-cache:
condition: service_healthy condition: service_healthy
restart: unless-stopped restart: unless-stopped
@ -210,3 +210,5 @@ volumes:
name: surfsense-redis name: surfsense-redis
shared_temp: shared_temp:
name: surfsense-shared-temp name: surfsense-shared-temp
zero_cache_data:
name: surfsense-zero-cache

View file

@ -1,11 +1,11 @@
# PostgreSQL configuration for Electric SQL # PostgreSQL configuration for SurfSense
# This file is mounted into the PostgreSQL container # This file is mounted into the PostgreSQL container
listen_addresses = '*' listen_addresses = '*'
max_connections = 200 max_connections = 200
shared_buffers = 256MB shared_buffers = 256MB
# Enable logical replication (required for Electric SQL) # Enable logical replication (required for Zero-cache real-time sync)
wal_level = logical wal_level = logical
max_replication_slots = 10 max_replication_slots = 10
max_wal_senders = 10 max_wal_senders = 10

View file

@ -1,38 +0,0 @@
#!/bin/sh
# Creates the Electric SQL replication user on first DB initialization.
# Idempotent — safe to run alongside Alembic migration 66.
set -e
ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
echo "Creating Electric SQL replication user: $ELECTRIC_DB_USER"
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
DO \$\$
BEGIN
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '$ELECTRIC_DB_USER') THEN
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
END IF;
END
\$\$;
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
GRANT CREATE ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
DO \$\$
BEGIN
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
CREATE PUBLICATION electric_publication_default;
END IF;
END
\$\$;
EOSQL
echo "Electric SQL user '$ELECTRIC_DB_USER' and publication created successfully"

View file

@ -109,7 +109,6 @@ $Files = @(
@{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" } @{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" }
@{ Src = "docker/.env.example"; Dest = ".env.example" } @{ Src = "docker/.env.example"; Dest = ".env.example" }
@{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" } @{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" }
@{ Src = "docker/scripts/init-electric-user.sh"; Dest = "scripts/init-electric-user.sh" }
@{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" } @{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" }
@{ Src = "docker/searxng/settings.yml"; Dest = "searxng/settings.yml" } @{ Src = "docker/searxng/settings.yml"; Dest = "searxng/settings.yml" }
@{ Src = "docker/searxng/limiter.toml"; Dest = "searxng/limiter.toml" } @{ Src = "docker/searxng/limiter.toml"; Dest = "searxng/limiter.toml" }

View file

@ -108,7 +108,6 @@ FILES=(
"docker/docker-compose.yml:docker-compose.yml" "docker/docker-compose.yml:docker-compose.yml"
"docker/.env.example:.env.example" "docker/.env.example:.env.example"
"docker/postgresql.conf:postgresql.conf" "docker/postgresql.conf:postgresql.conf"
"docker/scripts/init-electric-user.sh:scripts/init-electric-user.sh"
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh" "docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
"docker/searxng/settings.yml:searxng/settings.yml" "docker/searxng/settings.yml:searxng/settings.yml"
"docker/searxng/limiter.toml:searxng/limiter.toml" "docker/searxng/limiter.toml:searxng/limiter.toml"
@ -122,7 +121,6 @@ for entry in "${FILES[@]}"; do
|| error "Failed to download ${dest}. Check your internet connection and try again." || error "Failed to download ${dest}. Check your internet connection and try again."
done done
chmod +x "${INSTALL_DIR}/scripts/init-electric-user.sh"
chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh" chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh"
success "All files downloaded to ${INSTALL_DIR}/" success "All files downloaded to ${INSTALL_DIR}/"

View file

@ -17,10 +17,6 @@ REDIS_APP_URL=redis://localhost:6379/0
# Only uncomment if running the backend outside Docker (e.g. uvicorn on host). # Only uncomment if running the backend outside Docker (e.g. uvicorn on host).
# SEARXNG_DEFAULT_HOST=http://localhost:8888 # SEARXNG_DEFAULT_HOST=http://localhost:8888
#Electric(for migrations only)
ELECTRIC_DB_USER=electric
ELECTRIC_DB_PASSWORD=electric_password
# Periodic task interval # Periodic task interval
# # Run every minute (default) # # Run every minute (default)
# SCHEDULE_CHECKER_INTERVAL=1m # SCHEDULE_CHECKER_INTERVAL=1m
@ -104,7 +100,8 @@ TEAMS_CLIENT_ID=your_teams_client_id_here
TEAMS_CLIENT_SECRET=your_teams_client_secret_here TEAMS_CLIENT_SECRET=your_teams_client_secret_here
TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
#Composio Coonnector # Composio Connector
# NOTE: Disable "Mask Connected Account Secrets" in Composio dashboard (Settings → Project Settings) for Google indexing to work.
COMPOSIO_API_KEY=your_api_key_here COMPOSIO_API_KEY=your_api_key_here
COMPOSIO_ENABLED=TRUE COMPOSIO_ENABLED=TRUE
COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback

View file

@ -25,13 +25,6 @@ database_url = os.getenv("DATABASE_URL")
if database_url: if database_url:
config.set_main_option("sqlalchemy.url", database_url) config.set_main_option("sqlalchemy.url", database_url)
# Electric SQL user credentials - centralized configuration for migrations
# These are used by migrations that set up Electric SQL replication
config.set_main_option("electric_db_user", os.getenv("ELECTRIC_DB_USER", "electric"))
config.set_main_option(
"electric_db_password", os.getenv("ELECTRIC_DB_PASSWORD", "electric_password")
)
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:

View file

@ -30,21 +30,25 @@ def upgrade() -> None:
"ix_notifications_user_read_type_created", "ix_notifications_user_read_type_created",
"notifications", "notifications",
["user_id", "read", "type", "created_at"], ["user_id", "read", "type", "created_at"],
if_not_exists=True,
) )
op.create_index( op.create_index(
"ix_notifications_user_space_created", "ix_notifications_user_space_created",
"notifications", "notifications",
["user_id", "search_space_id", "created_at"], ["user_id", "search_space_id", "created_at"],
if_not_exists=True,
) )
op.create_index( op.create_index(
"ix_notifications_type", "ix_notifications_type",
"notifications", "notifications",
["type"], ["type"],
if_not_exists=True,
) )
op.create_index( op.create_index(
"ix_notifications_search_space_id", "ix_notifications_search_space_id",
"notifications", "notifications",
["search_space_id"], ["search_space_id"],
if_not_exists=True,
) )

View file

@ -35,52 +35,60 @@ def upgrade() -> None:
END $$; END $$;
""") """)
op.create_table( conn = op.get_bind()
"video_presentations", result = conn.execute(
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = 'video_presentations'")
sa.Column("title", sa.String(length=500), nullable=False),
sa.Column("slides", JSONB(), nullable=True),
sa.Column("scene_codes", JSONB(), nullable=True),
sa.Column(
"status",
video_presentation_status_enum,
server_default="ready",
nullable=False,
),
sa.Column("search_space_id", sa.Integer(), nullable=False),
sa.Column("thread_id", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["search_space_id"],
["searchspaces.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["thread_id"],
["new_chat_threads.id"],
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("id"),
) )
if not result.fetchone():
op.create_table(
"video_presentations",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("title", sa.String(length=500), nullable=False),
sa.Column("slides", JSONB(), nullable=True),
sa.Column("scene_codes", JSONB(), nullable=True),
sa.Column(
"status",
video_presentation_status_enum,
server_default="ready",
nullable=False,
),
sa.Column("search_space_id", sa.Integer(), nullable=False),
sa.Column("thread_id", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["search_space_id"],
["searchspaces.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["thread_id"],
["new_chat_threads.id"],
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index( op.create_index(
"ix_video_presentations_status", "ix_video_presentations_status",
"video_presentations", "video_presentations",
["status"], ["status"],
if_not_exists=True,
) )
op.create_index( op.create_index(
"ix_video_presentations_thread_id", "ix_video_presentations_thread_id",
"video_presentations", "video_presentations",
["thread_id"], ["thread_id"],
if_not_exists=True,
) )
op.create_index( op.create_index(
"ix_video_presentations_created_at", "ix_video_presentations_created_at",
"video_presentations", "video_presentations",
["created_at"], ["created_at"],
if_not_exists=True,
) )

View file

@ -0,0 +1,104 @@
"""Clean up Electric SQL artifacts (user, publication, replication slots)
Revision ID: 108
Revises: 107
Removes leftover Electric SQL infrastructure that is no longer needed after
the migration to Rocicorp Zero. Fully idempotent safe on databases that
never had Electric SQL set up (fresh installs).
Cleaned up:
- Replication slots containing 'electric' (prevents unbounded WAL growth)
- The 'electric_publication_default' publication
- Default privileges, grants, and the 'electric' database user
"""
from collections.abc import Sequence
from alembic import op
revision: str = "108"
down_revision: str | None = "107"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.execute(
"""
DO $$
DECLARE
slot RECORD;
BEGIN
-- 1. Drop inactive Electric replication slots (prevents WAL growth)
FOR slot IN
SELECT slot_name FROM pg_replication_slots
WHERE slot_name LIKE '%electric%' AND active = false
LOOP
BEGIN
PERFORM pg_drop_replication_slot(slot.slot_name);
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not drop replication slot %: %', slot.slot_name, SQLERRM;
END;
END LOOP;
-- Warn about active Electric slots that cannot be safely dropped
FOR slot IN
SELECT slot_name FROM pg_replication_slots
WHERE slot_name LIKE '%electric%' AND active = true
LOOP
RAISE WARNING 'Active Electric replication slot "%" was not dropped — drop it manually to stop WAL growth', slot.slot_name;
END LOOP;
-- 2. Drop the Electric publication
BEGIN
IF EXISTS (SELECT 1 FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
DROP PUBLICATION electric_publication_default;
END IF;
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not drop publication electric_publication_default: %', SQLERRM;
END;
-- 3. Revoke privileges and drop the Electric user
IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'electric') THEN
BEGIN
ALTER DEFAULT PRIVILEGES IN SCHEMA public
REVOKE SELECT ON TABLES FROM electric;
ALTER DEFAULT PRIVILEGES IN SCHEMA public
REVOKE SELECT ON SEQUENCES FROM electric;
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not revoke default privileges from electric: %', SQLERRM;
END;
BEGIN
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM electric;
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM electric;
REVOKE USAGE ON SCHEMA public FROM electric;
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not revoke schema privileges from electric: %', SQLERRM;
END;
BEGIN
EXECUTE format(
'REVOKE CONNECT ON DATABASE %I FROM electric',
current_database()
);
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not revoke CONNECT from electric: %', SQLERRM;
END;
BEGIN
REASSIGN OWNED BY electric TO CURRENT_USER;
DROP ROLE electric;
EXCEPTION WHEN OTHERS THEN
RAISE WARNING 'Could not drop role electric: %', SQLERRM;
END;
END IF;
END
$$;
"""
)
def downgrade() -> None:
pass

View file

@ -5,7 +5,7 @@ This module provides the SurfSense deep agent with configurable tools
for knowledge base search, podcast generation, and more. for knowledge base search, podcast generation, and more.
Directory Structure: Directory Structure:
- tools/: All agent tools (knowledge_base, podcast, link_preview, etc.) - tools/: All agent tools (knowledge_base, podcast, generate_image, etc.)
- chat_deepagent.py: Main agent factory - chat_deepagent.py: Main agent factory
- system_prompt.py: System prompts and instructions - system_prompt.py: System prompts and instructions
- context.py: Context schema for the agent - context.py: Context schema for the agent
@ -37,9 +37,7 @@ from .tools import (
BUILTIN_TOOLS, BUILTIN_TOOLS,
ToolDefinition, ToolDefinition,
build_tools, build_tools,
create_display_image_tool,
create_generate_podcast_tool, create_generate_podcast_tool,
create_link_preview_tool,
create_scrape_webpage_tool, create_scrape_webpage_tool,
create_search_knowledge_base_tool, create_search_knowledge_base_tool,
format_documents_for_context, format_documents_for_context,
@ -63,9 +61,7 @@ __all__ = [
# LLM config # LLM config
"create_chat_litellm_from_config", "create_chat_litellm_from_config",
# Tool factories # Tool factories
"create_display_image_tool",
"create_generate_podcast_tool", "create_generate_podcast_tool",
"create_link_preview_tool",
"create_scrape_webpage_tool", "create_scrape_webpage_tool",
"create_search_knowledge_base_tool", "create_search_knowledge_base_tool",
# Agent factory # Agent factory

View file

@ -21,6 +21,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.llm_config import AgentConfig from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.new_chat.system_prompt import ( from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt, build_configurable_system_prompt,
build_surfsense_system_prompt, build_surfsense_system_prompt,
@ -65,10 +68,11 @@ _CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR", "BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type "CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR", "OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
# Composio connectors # Composio connectors (unified to native document types).
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR", # Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
"COMPOSIO_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
} }
# Document types that don't come from SearchSourceConnector but should always be searchable # Document types that don't come from SearchSourceConnector but should always be searchable
@ -146,8 +150,6 @@ async def create_surfsense_deep_agent(
- search_knowledge_base: Search the user's personal knowledge base - search_knowledge_base: Search the user's personal knowledge base
- generate_podcast: Generate audio podcasts from content - generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models - generate_image: Generate images from text descriptions using AI models
- link_preview: Fetch rich previews for URLs
- display_image: Display images in chat
- scrape_webpage: Extract content from webpages - scrape_webpage: Extract content from webpages
- save_memory: Store facts/preferences about the user - save_memory: Store facts/preferences about the user
- recall_memory: Retrieve relevant user memories - recall_memory: Retrieve relevant user memories
@ -203,7 +205,7 @@ async def create_surfsense_deep_agent(
# Create agent with only specific tools # Create agent with only specific tools
agent = create_surfsense_deep_agent( agent = create_surfsense_deep_agent(
llm, search_space_id, db_session, ..., llm, search_space_id, db_session, ...,
enabled_tools=["search_knowledge_base", "link_preview"] enabled_tools=["search_knowledge_base", "scrape_webpage"]
) )
# Create agent without podcast generation # Create agent without podcast generation
@ -292,6 +294,69 @@ async def create_surfsense_deep_agent(
] ]
modified_disabled_tools.extend(linear_tools) modified_disabled_tools.extend(linear_tools)
# Disable Google Drive action tools if no Google Drive connector is configured
has_google_drive_connector = (
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
)
if not has_google_drive_connector:
google_drive_tools = [
"create_google_drive_file",
"delete_google_drive_file",
]
modified_disabled_tools.extend(google_drive_tools)
# Disable Google Calendar action tools if no Google Calendar connector is configured
has_google_calendar_connector = (
available_connectors is not None
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
)
if not has_google_calendar_connector:
calendar_tools = [
"create_calendar_event",
"update_calendar_event",
"delete_calendar_event",
]
modified_disabled_tools.extend(calendar_tools)
# Disable Gmail action tools if no Gmail connector is configured
has_gmail_connector = (
available_connectors is not None
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
)
if not has_gmail_connector:
gmail_tools = [
"create_gmail_draft",
"update_gmail_draft",
"send_gmail_email",
"trash_gmail_email",
]
modified_disabled_tools.extend(gmail_tools)
# Disable Jira action tools if no Jira connector is configured
has_jira_connector = (
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
)
if not has_jira_connector:
jira_tools = [
"create_jira_issue",
"update_jira_issue",
"delete_jira_issue",
]
modified_disabled_tools.extend(jira_tools)
# Disable Confluence action tools if no Confluence connector is configured
has_confluence_connector = (
available_connectors is not None
and "CONFLUENCE_CONNECTOR" in available_connectors
)
if not has_confluence_connector:
confluence_tools = [
"create_confluence_page",
"update_confluence_page",
"delete_confluence_page",
]
modified_disabled_tools.extend(confluence_tools)
# Build tools using the async registry (includes MCP tools) # Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter() _t0 = time.perf_counter()
tools = await build_tools_async( tools = await build_tools_async(
@ -345,6 +410,7 @@ async def create_surfsense_deep_agent(
system_prompt=system_prompt, system_prompt=system_prompt,
context_schema=SurfSenseContextSchema, context_schema=SurfSenseContextSchema,
checkpointer=checkpointer, checkpointer=checkpointer,
middleware=[DedupHITLToolCallsMiddleware()],
**deep_agent_kwargs, **deep_agent_kwargs,
) )
_perf_log.info( _perf_log.info(

View file

@ -0,0 +1,93 @@
"""Middleware that deduplicates HITL tool calls within a single LLM response.
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
"delete_calendar_event": "event_title_or_id",
"update_calendar_event": "event_title_or_id",
"trash_gmail_email": "email_subject_or_id",
"update_gmail_draft": "draft_subject_or_id",
"delete_google_drive_file": "file_name",
"delete_notion_page": "page_title",
"update_notion_page": "page_title",
"delete_linear_issue": "issue_ref",
"update_linear_issue": "issue_ref",
"update_jira_issue": "issue_title_or_key",
"delete_jira_issue": "issue_title_or_key",
"update_confluence_page": "page_title_or_id",
"delete_confluence_page": "page_title_or_id",
}
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Remove duplicate HITL tool calls from a single LLM response.
Only the **first** occurrence of each (tool-name, primary-arg-value)
pair is kept; subsequent duplicates are silently dropped.
"""
tools = ()
def after_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state)
async def aafter_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state)
@staticmethod
def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg]
messages = state.get("messages")
if not messages:
return None
last_msg = messages[-1]
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
return None
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
seen: set[tuple[str, str]] = set()
deduped: list[dict[str, Any]] = []
for tc in tool_calls:
name = tc.get("name", "")
dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
if dedup_key_arg is not None:
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
key = (name, arg_val)
if key in seen:
logger.info(
"Dedup: dropped duplicate HITL tool call %s(%s)",
name,
arg_val,
)
continue
seen.add(key)
deduped.append(tc)
if len(deduped) == len(tool_calls):
return None
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
return {"messages": [updated_msg]}

View file

@ -184,48 +184,6 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat. - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
""" """
_TOOL_INSTRUCTIONS["link_preview"] = """
- link_preview: Fetch metadata for a URL to display a rich preview card.
- IMPORTANT: Use this tool WHENEVER the user shares or mentions a URL/link in their message.
- This fetches the page's Open Graph metadata (title, description, thumbnail) to show a preview card.
- NOTE: This tool only fetches metadata, NOT the full page content. It cannot read the article text.
- Trigger scenarios:
* User shares a URL (e.g., "Check out https://example.com")
* User pastes a link in their message
* User asks about a URL or link
- Args:
- url: The URL to fetch metadata for (must be a valid HTTP/HTTPS URL)
- Returns: A rich preview card with title, description, thumbnail, and domain
- The preview card will automatically be displayed in the chat.
"""
_TOOL_INSTRUCTIONS["display_image"] = """
- display_image: Display an image in the chat with metadata.
- Use this tool ONLY when you have a valid public HTTP/HTTPS image URL to show.
- This displays the image with an optional title, description, and source attribution.
- Valid use cases:
* Showing an image from a URL the user explicitly mentioned in their message
* Displaying images found in scraped webpage content (from scrape_webpage tool)
* Showing a publicly accessible diagram or chart from a known URL
* Displaying an AI-generated image after calling the generate_image tool (ALWAYS required)
CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS:
When a user uploads/attaches an image file to their message:
* The image is ALREADY VISIBLE in the chat UI as a thumbnail on their message
* You do NOT have a URL for their uploaded image - only extracted text/description
* Calling display_image will FAIL and show "Image not available" error
* Simply analyze the image content and respond with your analysis - DO NOT try to display it
* The user can already see their own uploaded image - they don't need you to show it again
- Args:
- src: The URL of the image (MUST be a valid public HTTP/HTTPS URL that you know exists)
- alt: Alternative text describing the image (for accessibility)
- title: Optional title to display below the image
- description: Optional description providing context about the image
- Returns: An image card with the image, title, and description
- The image will automatically be displayed in the chat.
"""
_TOOL_INSTRUCTIONS["generate_image"] = """ _TOOL_INSTRUCTIONS["generate_image"] = """
- generate_image: Generate images from text descriptions using AI image models. - generate_image: Generate images from text descriptions using AI image models.
- Use this when the user asks you to create, generate, draw, design, or make an image. - Use this when the user asks you to create, generate, draw, design, or make an image.
@ -233,10 +191,7 @@ _TOOL_INSTRUCTIONS["generate_image"] = """
- Args: - Args:
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
- n: Number of images to generate (1-4, default: 1) - n: Number of images to generate (1-4, default: 1)
- Returns: A dictionary with the generated image URL in the "src" field, along with metadata. - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat.
- CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
to actually show the image in the chat. The generate_image tool only generates the image and returns
the URL it does NOT display anything. You must always follow up with display_image.
- IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
expand and improve the prompt with specific details about style, lighting, composition, and mood. expand and improve the prompt with specific details about style, lighting, composition, and mood.
- If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
@ -245,14 +200,11 @@ _TOOL_INSTRUCTIONS["generate_image"] = """
_TOOL_INSTRUCTIONS["scrape_webpage"] = """ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
- scrape_webpage: Scrape and extract the main content from a webpage. - scrape_webpage: Scrape and extract the main content from a webpage.
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
- IMPORTANT: This is different from link_preview:
* link_preview: Only fetches metadata (title, description, thumbnail) for display
* scrape_webpage: Actually reads the FULL page content so you can analyze/summarize it
- CRITICAL WHEN TO USE (always attempt scraping, never refuse before trying): - CRITICAL WHEN TO USE (always attempt scraping, never refuse before trying):
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
* When a URL was mentioned earlier in the conversation and the user asks for its actual content * When a URL was mentioned earlier in the conversation and the user asks for its actual content
* When link_preview or search_knowledge_base returned insufficient data and the user wants more * When search_knowledge_base returned insufficient data and the user wants more
- Trigger scenarios: - Trigger scenarios:
* "Read this article and summarize it" * "Read this article and summarize it"
* "What does this page say about X?" * "What does this page say about X?"
@ -268,9 +220,10 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
- url: The URL of the webpage to scrape (must be HTTP/HTTPS) - url: The URL of the webpage to scrape (must be HTTP/HTTPS)
- max_length: Maximum content length to return (default: 50000 chars) - max_length: Maximum content length to return (default: 50000 chars)
- Returns: The page title, description, full content (in markdown), word count, and metadata - Returns: The page title, description, full content (in markdown), word count, and metadata
- After scraping, you will have the full article text and can analyze, summarize, or answer questions about it. - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points.
- Reference the source using markdown links [descriptive text](url) never bare URLs.
- IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`. - IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`.
* When you find relevant/important images in the scraped content, use the `display_image` tool to show them to the user. * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt text](image_url)`.
* This makes your response more visual and engaging. * This makes your response more visual and engaging.
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
* Don't show every image - just the most relevant 1-3 images that enhance understanding. * Don't show every image - just the most relevant 1-3 images that enhance understanding.
@ -292,6 +245,8 @@ _TOOL_INSTRUCTIONS["web_search"] = """
- Args: - Args:
- query: The search query - use specific, descriptive terms - query: The search query - use specific, descriptive terms
- top_k: Number of results to retrieve (default: 10, max: 50) - top_k: Number of results to retrieve (default: 10, max: 50)
- If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content.
- When presenting results, reference sources as markdown links [descriptive text](url) never bare URLs.
""" """
# Memory tool instructions have private and shared variants. # Memory tool instructions have private and shared variants.
@ -476,32 +431,31 @@ _TOOL_EXAMPLES["generate_report"] = """
_TOOL_EXAMPLES["scrape_webpage"] = """ _TOOL_EXAMPLES["scrape_webpage"] = """
- User: "Check out https://dev.to/some-article" - User: "Check out https://dev.to/some-article"
- Call: `link_preview(url="https://dev.to/some-article")`
- Call: `scrape_webpage(url="https://dev.to/some-article")` - Call: `scrape_webpage(url="https://dev.to/some-article")`
- Then provide your analysis of the content. - Respond with a structured analysis key points, takeaways.
- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends" - User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
- Call: `scrape_webpage(url="https://example.com/blog/ai-trends")` - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
- Then provide a summary based on the scraped text. - Respond with a thorough summary using headings and bullet points.
- User: (after discussing https://example.com/stats) "Can you get the live data from that page?" - User: (after discussing https://example.com/stats) "Can you get the live data from that page?"
- Call: `scrape_webpage(url="https://example.com/stats")` - Call: `scrape_webpage(url="https://example.com/stats")`
- IMPORTANT: Always attempt scraping first. Never refuse before trying the tool. - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool.
""" - User: "https://example.com/blog/weekend-recipes"
- Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")`
_TOOL_EXAMPLES["display_image"] = """ - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content.
- User: "Show me this image: https://example.com/image.png"
- Call: `display_image(src="https://example.com/image.png", alt="User shared image")`
- User uploads an image file and asks: "What is this image about?"
- DO NOT call display_image! The user's uploaded image is already visible in the chat.
- Simply analyze the image content and respond directly.
""" """
_TOOL_EXAMPLES["generate_image"] = """ _TOOL_EXAMPLES["generate_image"] = """
- User: "Generate an image of a cat" - User: "Generate an image of a cat"
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
- Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")` - The generated image will automatically be displayed in the chat.
- User: "Draw me a logo for a coffee shop called Bean Dream" - User: "Draw me a logo for a coffee shop called Bean Dream"
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")` - The generated image will automatically be displayed in the chat.
- User: "Show me this image: https://example.com/image.png"
- Simply include it in your response using markdown: `![Image](https://example.com/image.png)`
- User uploads an image file and asks: "What is this image about?"
- The user's uploaded image is already visible in the chat.
- Simply analyze the image content and respond directly.
""" """
_TOOL_EXAMPLES["web_search"] = """ _TOOL_EXAMPLES["web_search"] = """
@ -522,8 +476,6 @@ _ALL_TOOL_NAMES_ORDERED = [
"generate_podcast", "generate_podcast",
"generate_video_presentation", "generate_video_presentation",
"generate_report", "generate_report",
"link_preview",
"display_image",
"generate_image", "generate_image",
"scrape_webpage", "scrape_webpage",
"save_memory", "save_memory",
@ -764,7 +716,7 @@ Do not use the sandbox for:
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox: When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")` - **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker. - **DO NOT use markdown image syntax** for files created inside the sandbox. Sandbox files are not accessible via public URLs and will show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")` - You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading. - Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file. - IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.

View file

@ -10,8 +10,6 @@ Available tools:
- generate_podcast: Generate audio podcasts from content - generate_podcast: Generate audio podcasts from content
- generate_video_presentation: Generate video presentations with slides and narration - generate_video_presentation: Generate video presentations with slides and narration
- generate_image: Generate images from text descriptions using AI models - generate_image: Generate images from text descriptions using AI models
- link_preview: Fetch rich previews for URLs
- display_image: Display images in chat
- scrape_webpage: Extract content from webpages - scrape_webpage: Extract content from webpages
- save_memory: Store facts/preferences about the user - save_memory: Store facts/preferences about the user
- recall_memory: Retrieve relevant user memories - recall_memory: Retrieve relevant user memories
@ -19,7 +17,6 @@ Available tools:
# Registry exports # Registry exports
# Tool factory exports (for direct use) # Tool factory exports (for direct use)
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool from .generate_image import create_generate_image_tool
from .knowledge_base import ( from .knowledge_base import (
CONNECTOR_DESCRIPTIONS, CONNECTOR_DESCRIPTIONS,
@ -27,7 +24,6 @@ from .knowledge_base import (
format_documents_for_context, format_documents_for_context,
search_knowledge_base_async, search_knowledge_base_async,
) )
from .link_preview import create_link_preview_tool
from .podcast import create_generate_podcast_tool from .podcast import create_generate_podcast_tool
from .registry import ( from .registry import (
BUILTIN_TOOLS, BUILTIN_TOOLS,
@ -50,11 +46,9 @@ __all__ = [
"ToolDefinition", "ToolDefinition",
"build_tools", "build_tools",
# Tool factories # Tool factories
"create_display_image_tool",
"create_generate_image_tool", "create_generate_image_tool",
"create_generate_podcast_tool", "create_generate_podcast_tool",
"create_generate_video_presentation_tool", "create_generate_video_presentation_tool",
"create_link_preview_tool",
"create_recall_memory_tool", "create_recall_memory_tool",
"create_save_memory_tool", "create_save_memory_tool",
"create_scrape_webpage_tool", "create_scrape_webpage_tool",

View file

@ -0,0 +1,11 @@
"""Confluence tools for creating, updating, and deleting pages."""
from .create_page import create_create_confluence_page_tool
from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool
__all__ = [
"create_create_confluence_page_tool",
"create_delete_confluence_page_tool",
"create_update_confluence_page_tool",
]

View file

@ -0,0 +1,237 @@
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_create_confluence_page_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def create_confluence_page(
title: str,
content: str | None = None,
space_id: str | None = None,
) -> dict[str, Any]:
"""Create a new page in Confluence.
Use this tool when the user explicitly asks to create a new Confluence page.
Args:
title: Title of the page.
content: Optional HTML/storage format content for the page body.
space_id: Optional Confluence space ID to create the page in.
Returns:
Dictionary with status, page_id, and message.
IMPORTANT:
- If status is "rejected", do NOT retry.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(f"create_confluence_page called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Confluence accounts need re-authentication.",
"connector_type": "confluence",
}
approval = interrupt(
{
"type": "confluence_page_creation",
"action": {
"tool": "create_confluence_page",
"params": {
"title": title,
"content": content,
"space_id": space_id,
"connector_id": connector_id,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The page was not created.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_title = final_params.get("title", title)
final_content = final_params.get("content", content) or ""
final_space_id = final_params.get("space_id", space_id)
final_connector_id = final_params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
return {"status": "error", "message": "Page title cannot be empty."}
if not final_space_id:
return {"status": "error", "message": "A space must be selected."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Confluence connector found.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
api_result = await client.create_page(
space_id=final_space_id,
title=final_title,
body=final_content,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_id = str(api_result.get("id", ""))
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=page_id,
page_title=final_title,
space_id=final_space_id,
body_content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"page_id": page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating Confluence page: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the page.",
}
return create_confluence_page

View file

@ -0,0 +1,215 @@
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_confluence_page_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def delete_confluence_page(
page_title_or_id: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Delete a Confluence page.
Use this tool when the user asks to delete or remove a Confluence page.
Args:
page_title_or_id: The page title or ID to identify the page.
delete_from_kb: Whether to also remove from the knowledge base.
Returns:
Dictionary with status, message, and deleted_from_kb.
IMPORTANT:
- If status is "rejected", do NOT retry.
- If status is "not_found", relay the message to the user.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, page_title_or_id
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
page_title = page_data.get("page_title", "")
document_id = page_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
approval = interrupt(
{
"type": "confluence_page_deletion",
"action": {
"tool": "delete_confluence_page",
"params": {
"page_id": page_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The page was not deleted.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_page_id = final_params.get("page_id", page_id)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
await client.delete_page(final_page_id)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Confluence page '{page_title}' deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"page_id": final_page_id,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error deleting Confluence page: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while deleting the page.",
}
return delete_confluence_page

View file

@ -0,0 +1,244 @@
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_update_confluence_page_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def update_confluence_page(
page_title_or_id: str,
new_title: str | None = None,
new_content: str | None = None,
) -> dict[str, Any]:
"""Update an existing Confluence page.
Use this tool when the user asks to modify or edit a Confluence page.
Args:
page_title_or_id: The page title or ID to identify the page.
new_title: Optional new title for the page.
new_content: Optional new HTML/storage format content.
Returns:
Dictionary with status and message.
IMPORTANT:
- If status is "rejected", do NOT retry.
- If status is "not_found", relay the message to the user.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Confluence tool not properly configured.",
}
try:
metadata_service = ConfluenceToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, page_title_or_id
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "confluence",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
page_data = context["page"]
page_id = page_data["page_id"]
current_title = page_data["page_title"]
current_body = page_data.get("body", "")
current_version = page_data.get("version", 1)
document_id = page_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
approval = interrupt(
{
"type": "confluence_page_update",
"action": {
"tool": "update_confluence_page",
"params": {
"page_id": page_id,
"document_id": document_id,
"new_title": new_title,
"new_content": new_content,
"version": current_version,
"connector_id": connector_id_from_context,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The page was not updated.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_page_id = final_params.get("page_id", page_id)
final_title = final_params.get("new_title", new_title) or current_title
final_content = final_params.get("new_content", new_content)
if final_content is None:
final_content = current_body
final_version = final_params.get("version", current_version)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_document_id = final_params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this page.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Confluence connector is invalid.",
}
try:
client = ConfluenceHistoryConnector(
session=db_session, connector_id=final_connector_id
)
api_result = await client.update_page(
page_id=final_page_id,
title=final_title,
body=final_content,
version_number=final_version + 1,
)
await client.close()
except Exception as api_err:
if (
"http 403" in str(api_err).lower()
or "status code 403" in str(api_err).lower()
):
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
page_links = (
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
)
page_url = ""
if page_links.get("base") and page_links.get("webui"):
page_url = f"{page_links['base']}{page_links['webui']}"
kb_message_suffix = ""
if final_document_id:
try:
from app.services.confluence import ConfluenceKBSyncService
kb_service = ConfluenceKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
page_id=final_page_id,
user_id=user_id,
search_space_id=search_space_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return {
"status": "success",
"page_id": final_page_id,
"page_url": page_url,
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error updating Confluence page: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while updating the page.",
}
return update_confluence_page

View file

@ -1,111 +0,0 @@
"""
Display image tool for the SurfSense agent.
This module provides a tool for displaying images in the chat UI
with metadata like title, description, and source attribution.
"""
import hashlib
from typing import Any
from urllib.parse import urlparse
from langchain_core.tools import tool
def extract_domain(url: str) -> str:
"""Extract the domain from a URL."""
try:
parsed = urlparse(url)
domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."):
domain = domain[4:]
return domain
except Exception:
return ""
def generate_image_id(src: str) -> str:
"""Generate a unique ID for an image."""
hash_val = hashlib.md5(src.encode()).hexdigest()[:12]
return f"image-{hash_val}"
def create_display_image_tool():
"""
Factory function to create the display_image tool.
Returns:
A configured tool function for displaying images.
"""
@tool
async def display_image(
src: str,
alt: str = "Image",
title: str | None = None,
description: str | None = None,
) -> dict[str, Any]:
"""
Display an image in the chat with metadata.
Use this tool when you want to show an image to the user.
This displays the image with an optional title, description,
and source attribution.
Common use cases:
- Showing an image from a URL the user mentioned
- Displaying a diagram or chart you're referencing
- Showing example images when explaining concepts
Args:
src: The URL of the image to display (must be a valid HTTP/HTTPS URL)
alt: Alternative text describing the image (for accessibility)
title: Optional title to display below the image
description: Optional description providing context about the image
Returns:
A dictionary containing image metadata for the UI to render:
- id: Unique identifier for this image
- assetId: The image URL (for deduplication)
- src: The image URL
- alt: Alt text for accessibility
- title: Image title (if provided)
- description: Image description (if provided)
- domain: Source domain
"""
image_id = generate_image_id(src)
# Ensure URL has protocol
if not src.startswith(("http://", "https://")):
src = f"https://{src}"
domain = extract_domain(src)
# Determine aspect ratio based on image source
# AI-generated images should use "auto" to preserve their native ratio
is_generated = "/image-generations/" in src
if is_generated:
ratio = "auto"
domain = "ai-generated"
elif "unsplash.com" in src or "pexels.com" in src:
ratio = "16:9"
elif (
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
):
ratio = "auto"
else:
ratio = "auto"
return {
"id": image_id,
"assetId": src,
"src": src,
"alt": alt,
"title": title,
"description": description,
"domain": domain,
"ratio": ratio,
}
return display_image

View file

@ -2,8 +2,7 @@
Image generation tool for the SurfSense agent. Image generation tool for the SurfSense agent.
This module provides a tool that generates images using litellm.aimage_generation() This module provides a tool that generates images using litellm.aimage_generation()
and returns the result via the existing display_image tool format so the frontend and returns the result directly in a format the frontend Image component can render.
renders the generated image inline in the chat.
Config resolution: Config resolution:
1. Uses the search space's image_generation_config_id preference 1. Uses the search space's image_generation_config_id preference
@ -11,6 +10,7 @@ Config resolution:
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs) 3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
""" """
import hashlib
import logging import logging
from typing import Any from typing import Any
@ -222,11 +222,17 @@ def create_generate_image_tool(
else: else:
return {"error": "No displayable image data in the response"} return {"error": "No displayable image data in the response"}
image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}"
return { return {
"id": image_id,
"assetId": image_url,
"src": image_url, "src": image_url,
"alt": revised_prompt or prompt, "alt": revised_prompt or prompt,
"title": "Generated Image", "title": "Generated Image",
"description": revised_prompt if revised_prompt != prompt else None, "description": revised_prompt if revised_prompt != prompt else None,
"domain": "ai-generated",
"ratio": "auto",
"generated": True, "generated": True,
"prompt": prompt, "prompt": prompt,
"image_count": len(images), "image_count": len(images),

View file

@ -0,0 +1,19 @@
from app.agents.new_chat.tools.gmail.create_draft import (
create_create_gmail_draft_tool,
)
from app.agents.new_chat.tools.gmail.send_email import (
create_send_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.trash_email import (
create_trash_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.update_draft import (
create_update_gmail_draft_tool,
)
__all__ = [
"create_create_gmail_draft_tool",
"create_send_gmail_email_tool",
"create_trash_gmail_email_tool",
"create_update_gmail_draft_tool",
]

View file

@ -0,0 +1,341 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_create_gmail_draft_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def create_gmail_draft(
to: str,
subject: str,
body: str,
cc: str | None = None,
bcc: str | None = None,
) -> dict[str, Any]:
"""Create a draft email in Gmail.
Use when the user asks to draft, compose, or prepare an email without
sending it.
Args:
to: Recipient email address.
subject: Email subject line.
body: Email body content.
cc: Optional CC recipient(s), comma-separated.
bcc: Optional BCC recipient(s), comma-separated.
Returns:
Dictionary with:
- status: "success", "rejected", or "error"
- draft_id: Gmail draft ID (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry the action.
Examples:
- "Draft an email to alice@example.com about the meeting"
- "Compose a reply to Bob about the project update"
"""
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
)
approval = interrupt(
{
"type": "gmail_draft_creation",
"action": {
"tool": "create_gmail_draft",
"params": {
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_to = final_params.get("to", to)
final_subject = final_params.get("subject", subject)
final_body = final_params.get("body", body)
final_cc = final_params.get("cc", cc)
final_bcc = final_params.get("bcc", bcc)
final_connector_id = final_params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.create(userId="me", body={"message": {"raw": raw}})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Gmail draft created: id={created.get('id')}")
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
draft_message = created.get("message", {})
kb_result = await kb_service.sync_after_create(
message_id=draft_message.get("id", ""),
thread_id=draft_message.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
draft_id=created.get("id"),
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"draft_id": created.get("id"),
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the draft. Please try again.",
}
return create_gmail_draft

View file

@ -0,0 +1,343 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_send_gmail_email_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def send_gmail_email(
to: str,
subject: str,
body: str,
cc: str | None = None,
bcc: str | None = None,
) -> dict[str, Any]:
"""Send an email via Gmail.
Use when the user explicitly asks to send an email. This sends the
email immediately - it cannot be unsent.
Args:
to: Recipient email address.
subject: Email subject line.
body: Email body content.
cc: Optional CC recipient(s), comma-separated.
bcc: Optional BCC recipient(s), comma-separated.
Returns:
Dictionary with:
- status: "success", "rejected", or "error"
- message_id: Gmail message ID (if success)
- thread_id: Gmail thread ID (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry the action.
Examples:
- "Send an email to alice@example.com about the meeting"
- "Email Bob the project update"
"""
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Gmail accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
logger.info(
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
)
approval = interrupt(
{
"type": "gmail_email_send",
"action": {
"tool": "send_gmail_email",
"params": {
"to": to,
"subject": subject,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": None,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_to = final_params.get("to", to)
final_subject = final_params.get("subject", subject)
final_body = final_params.get("body", body)
final_cc = final_params.get("cc", cc)
final_bcc = final_params.get("bcc", bcc)
final_connector_id = final_params.get("connector_id")
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
message = MIMEText(final_body)
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
sent = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.send(userId="me", body={"raw": raw})
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
)
kb_message_suffix = ""
try:
from app.services.gmail import GmailKBSyncService
kb_service = GmailKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
message_id=sent.get("id", ""),
thread_id=sent.get("threadId", ""),
subject=final_subject,
sender="me",
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
body_text=final_body,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after send failed: {kb_err}")
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"message_id": sent.get("id"),
"thread_id": sent.get("threadId"),
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error sending Gmail email: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while sending the email. Please try again.",
}
return send_gmail_email

View file

@ -0,0 +1,337 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_trash_gmail_email_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def trash_gmail_email(
email_subject_or_id: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Move an email or draft to trash in Gmail.
Use when the user asks to delete, remove, or trash an email or draft.
Args:
email_subject_or_id: The exact subject line or message ID of the
email to trash (as it appears in the inbox).
delete_from_kb: Whether to also remove the email from the knowledge base.
Default is False.
Set to True to remove from both Gmail and knowledge base.
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", or "error"
- message_id: Gmail message ID (if success)
- deleted_from_kb: whether the document was removed from the knowledge base
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined. Respond with a brief
acknowledgment and do NOT retry or suggest alternatives.
- If status is "not_found", relay the exact message to the user and ask them
to verify the email subject or check if it has been indexed.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry this tool.
Examples:
- "Delete the email about 'Meeting Cancelled'"
- "Trash the email from Bob about the project"
"""
logger.info(
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, email_subject_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Email not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = context["account"]["id"]
if not message_id:
return {
"status": "error",
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
}
logger.info(
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
)
approval = interrupt(
{
"type": "gmail_email_trash",
"action": {
"tool": "trash_gmail_email",
"params": {
"message_id": message_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
}
edited_action = decision.get("edited_action")
final_params: dict[str, Any] = {}
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_message_id = final_params.get("message_id", message_id)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this email.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.messages()
.trash(userId="me", id=final_message_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Gmail email trashed: message_id={final_message_id}")
trash_result: dict[str, Any] = {
"status": "success",
"message_id": final_message_id,
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"Email trashed, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while trashing the email. Please try again.",
}
return trash_gmail_email

View file

@ -0,0 +1,438 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_update_gmail_draft_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def update_gmail_draft(
draft_subject_or_id: str,
body: str,
to: str | None = None,
subject: str | None = None,
cc: str | None = None,
bcc: str | None = None,
) -> dict[str, Any]:
"""Update an existing Gmail draft.
Use when the user asks to modify, edit, or add content to an existing
email draft. This replaces the draft content with the new version.
The user will be able to review and edit the content before it is applied.
If the user simply wants to "edit" a draft without specifying exact changes,
generate the body yourself using your best understanding of the conversation
context. The user will review and can freely edit the content in the approval
card before confirming.
IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
calendar events, or any other content type.
Args:
draft_subject_or_id: The exact subject line of the draft to update
(as it appears in Gmail drafts).
body: The full updated body content for the draft. Generate this
yourself based on the user's request and conversation context.
to: Optional new recipient email address (keeps original if omitted).
subject: Optional new subject line (keeps original if omitted).
cc: Optional CC recipient(s), comma-separated.
bcc: Optional BCC recipient(s), comma-separated.
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", or "error"
- draft_id: Gmail draft ID (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
- If status is "not_found", relay the exact message to the user and ask them
to verify the draft subject or check if it has been indexed.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry the action.
Examples:
- "Update the Kurseong Plan draft with the new itinerary details"
- "Edit my draft about the project proposal and change the recipient"
- "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
"""
logger.info(
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Gmail tool not properly configured. Please contact support.",
}
try:
metadata_service = GmailToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, draft_subject_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Draft not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Gmail account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "gmail",
}
email = context["email"]
message_id = email["message_id"]
document_id = email.get("document_id")
connector_id_from_context = account["id"]
draft_id_from_context = context.get("draft_id")
original_subject = email.get("subject", draft_subject_or_id)
final_subject_default = subject if subject else original_subject
final_to_default = to if to else ""
logger.info(
f"Requesting approval for updating Gmail draft: '{original_subject}' "
f"(message_id={message_id}, draft_id={draft_id_from_context})"
)
approval = interrupt(
{
"type": "gmail_draft_update",
"action": {
"tool": "update_gmail_draft",
"params": {
"message_id": message_id,
"draft_id": draft_id_from_context,
"to": final_to_default,
"subject": final_subject_default,
"body": body,
"cc": cc,
"bcc": bcc,
"connector_id": connector_id_from_context,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_to = final_params.get("to", final_to_default)
final_subject = final_params.get("subject", final_subject_default)
final_body = final_params.get("body", body)
final_cc = final_params.get("cc", cc)
final_bcc = final_params.get("bcc", bcc)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_draft_id = final_params.get("draft_id", draft_id_from_context)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this draft.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_gmail_types = [
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_gmail_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Gmail connector is invalid or has been disconnected.",
}
logger.info(
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
else:
from google.oauth2.credentials import Credentials
from app.config import config
from app.utils.oauth_security import TokenEncryption
config_data = dict(connector.config)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
from googleapiclient.discovery import build
gmail_service = build("gmail", "v1", credentials=creds)
# Resolve draft_id if not already available
if not final_draft_id:
logger.info(
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
)
final_draft_id = await _find_draft_id_by_message(
gmail_service, message_id
)
if not final_draft_id:
return {
"status": "error",
"message": (
"Could not find this draft in Gmail. "
"It may have already been sent or deleted."
),
}
message = MIMEText(final_body)
if final_to:
message["to"] = final_to
message["subject"] = final_subject
if final_cc:
message["cc"] = final_cc
if final_bcc:
message["bcc"] = final_bcc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
try:
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
gmail_service.users()
.drafts()
.update(
userId="me",
id=final_draft_id,
body={"message": {"raw": raw}},
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
}
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
return {
"status": "error",
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
}
raise
logger.info(f"Gmail draft updated: id={updated.get('id')}")
kb_message_suffix = ""
if document_id:
try:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
from app.db import Document
doc_result = await db_session.execute(
sa_select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
document.source_markdown = final_body
document.title = final_subject
meta = dict(document.document_metadata or {})
meta["subject"] = final_subject
meta["draft_id"] = updated.get("id", final_draft_id)
updated_msg = updated.get("message", {})
if updated_msg.get("id"):
meta["message_id"] = updated_msg["id"]
document.document_metadata = meta
flag_modified(document, "document_metadata")
await db_session.commit()
kb_message_suffix = (
" Your knowledge base has also been updated."
)
logger.info(
f"KB document {document_id} updated for draft {final_draft_id}"
)
else:
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB update after draft edit failed: {kb_err}")
await db_session.rollback()
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
return {
"status": "success",
"draft_id": updated.get("id"),
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while updating the draft. Please try again.",
}
return update_gmail_draft
async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
"""Look up a draft's ID by its message ID via the Gmail API."""
try:
page_token = None
while True:
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
if page_token:
kwargs["pageToken"] = page_token
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda kwargs=kwargs: (
gmail_service.users().drafts().list(**kwargs).execute()
),
)
for draft in response.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft["id"]
page_token = response.get("nextPageToken")
if not page_token:
break
return None
except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}")
return None

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.google_calendar.create_event import (
create_create_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.delete_event import (
create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
create_update_calendar_event_tool,
)
__all__ = [
"create_create_calendar_event_tool",
"create_delete_calendar_event_tool",
"create_update_calendar_event_tool",
]

View file

@ -0,0 +1,352 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def create_create_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def create_calendar_event(
summary: str,
start_datetime: str,
end_datetime: str,
description: str | None = None,
location: str | None = None,
attendees: list[str] | None = None,
) -> dict[str, Any]:
"""Create a new event on Google Calendar.
Use when the user asks to schedule, create, or add a calendar event.
Ask for event details if not provided.
Args:
summary: The event title.
start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
description: Optional event description.
location: Optional event location.
attendees: Optional list of attendee email addresses.
Returns:
Dictionary with:
- status: "success", "rejected", "auth_error", or "error"
- event_id: Google Calendar event ID (if success)
- html_link: URL to open the event (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
Examples:
- "Schedule a meeting with John tomorrow at 10am"
- "Create a calendar event for the team standup"
"""
logger.info(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
"All Google Calendar accounts have expired authentication"
)
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for creating calendar event: summary='{summary}'"
)
approval = interrupt(
{
"type": "google_calendar_event_creation",
"action": {
"tool": "create_calendar_event",
"params": {
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": description,
"location": location,
"attendees": attendees,
"timezone": context.get("timezone"),
"connector_id": None,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_summary = final_params.get("summary", summary)
final_start_datetime = final_params.get("start_datetime", start_datetime)
final_end_datetime = final_params.get("end_datetime", end_datetime)
final_description = final_params.get("description", description)
final_location = final_params.get("location", location)
final_attendees = final_params.get("attendees", attendees)
final_connector_id = final_params.get("connector_id")
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Event summary cannot be empty."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
"start": {"dateTime": final_start_datetime, "timeZone": tz},
"end": {"dateTime": final_end_datetime, "timeZone": tz},
}
if final_description:
event_body["description"] = final_description
if final_location:
event_body["location"] = final_location
if final_attendees:
event_body["attendees"] = [
{"email": e.strip()} for e in final_attendees if e.strip()
]
try:
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.insert(calendarId="primary", body=event_body)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"event_id": created.get("id"),
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating calendar event: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the event. Please try again.",
}
return create_calendar_event

View file

@ -0,0 +1,332 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def delete_calendar_event(
event_title_or_id: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Delete a Google Calendar event.
Use when the user asks to delete, remove, or cancel a calendar event.
Args:
event_title_or_id: The exact title or event ID of the event to delete.
delete_from_kb: Whether to also remove the event from the knowledge base.
Default is False.
Set to True to remove from both Google Calendar and knowledge base.
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", "auth_error", or "error"
- event_id: Google Calendar event ID (if success)
- deleted_from_kb: whether the document was removed from the knowledge base
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined. Respond with a brief
acknowledgment and do NOT retry or suggest alternatives.
- If status is "not_found", relay the exact message to the user and ask them
to verify the event name or check if it has been indexed.
Examples:
- "Delete the team standup event"
- "Cancel my dentist appointment on Friday"
"""
logger.info(
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, event_title_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch deletion context: {error_msg}")
return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Calendar account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if not event_id:
return {
"status": "error",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
logger.info(
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
)
approval = interrupt(
{
"type": "google_calendar_event_deletion",
"action": {
"tool": "delete_calendar_event",
"params": {
"event_id": event_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
}
edited_action = decision.get("edited_action")
final_params: dict[str, Any] = {}
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_event_id = final_params.get("event_id", event_id)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
try:
await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.delete(calendarId="primary", eventId=final_event_id)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event deleted: event_id={final_event_id}")
delete_result: dict[str, Any] = {
"status": "success",
"event_id": final_event_id,
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
delete_result["warning"] = (
f"Event deleted, but failed to remove from knowledge base: {e!s}"
)
delete_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
delete_result["message"] = (
f"{delete_result.get('message', '')} (also removed from knowledge base)"
)
return delete_result
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error deleting calendar event: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while deleting the event. Please try again.",
}
return delete_calendar_event

View file

@ -0,0 +1,382 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def create_update_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def update_calendar_event(
event_title_or_id: str,
new_summary: str | None = None,
new_start_datetime: str | None = None,
new_end_datetime: str | None = None,
new_description: str | None = None,
new_location: str | None = None,
new_attendees: list[str] | None = None,
) -> dict[str, Any]:
"""Update an existing Google Calendar event.
Use when the user asks to modify, reschedule, or change a calendar event.
Args:
event_title_or_id: The exact title or event ID of the event to update.
new_summary: New event title (if changing).
new_start_datetime: New start time in ISO 8601 format (if rescheduling).
new_end_datetime: New end time in ISO 8601 format (if rescheduling).
new_description: New event description (if changing).
new_location: New event location (if changing).
new_attendees: New list of attendee email addresses (if changing).
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", "auth_error", or "error"
- event_id: Google Calendar event ID (if success)
- html_link: URL to open the event (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined. Respond with a brief
acknowledgment and do NOT retry or suggest alternatives.
- If status is "not_found", relay the exact message to the user and ask them
to verify the event name or check if it has been indexed.
Examples:
- "Reschedule the team standup to 3pm"
- "Change the location of my dentist appointment"
"""
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, event_title_or_id
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"Event not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch update context: {error_msg}")
return {"status": "error", "message": error_msg}
if context.get("auth_expired"):
logger.warning("Google Calendar account has expired authentication")
return {
"status": "auth_error",
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
event = context["event"]
event_id = event["event_id"]
document_id = event.get("document_id")
connector_id_from_context = context["account"]["id"]
if not event_id:
return {
"status": "error",
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
}
logger.info(
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
)
approval = interrupt(
{
"type": "google_calendar_event_update",
"action": {
"tool": "update_calendar_event",
"params": {
"event_id": event_id,
"document_id": document_id,
"connector_id": connector_id_from_context,
"new_summary": new_summary,
"new_start_datetime": new_start_datetime,
"new_end_datetime": new_end_datetime,
"new_description": new_description,
"new_location": new_location,
"new_attendees": new_attendees,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
}
edited_action = decision.get("edited_action")
final_params: dict[str, Any] = {}
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_event_id = final_params.get("event_id", event_id)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_new_summary = final_params.get("new_summary", new_summary)
final_new_start_datetime = final_params.get(
"new_start_datetime", new_start_datetime
)
final_new_end_datetime = final_params.get(
"new_end_datetime", new_end_datetime
)
final_new_description = final_params.get("new_description", new_description)
final_new_location = final_params.get("new_location", new_location)
final_new_attendees = final_params.get("new_attendees", new_attendees)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this event.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
)
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
creds = build_composio_credentials(cca_id)
else:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
update_body: dict[str, Any] = {}
if final_new_summary is not None:
update_body["summary"] = final_new_summary
if final_new_start_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
)
update_body["start"] = {
"dateTime": final_new_start_datetime,
"timeZone": tz,
}
if final_new_end_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
)
update_body["end"] = {
"dateTime": final_new_end_datetime,
"timeZone": tz,
}
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:
update_body["location"] = final_new_location
if final_new_attendees is not None:
update_body["attendees"] = [
{"email": e.strip()} for e in final_new_attendees if e.strip()
]
if not update_body:
return {
"status": "error",
"message": "No changes specified. Please provide at least one field to update.",
}
try:
updated = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.patch(
calendarId="primary",
eventId=final_event_id,
body=update_body,
)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(f"Calendar event updated: event_id={final_event_id}")
kb_message_suffix = ""
if document_id is not None:
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=document_id,
event_id=final_event_id,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
return {
"status": "success",
"event_id": final_event_id,
"html_link": updated.get("htmlLink"),
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error updating calendar event: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while updating the event. Please try again.",
}
return update_calendar_event

View file

@ -32,13 +32,16 @@ def create_create_google_drive_file_tool(
"""Create a new Google Doc or Google Sheet in Google Drive. """Create a new Google Doc or Google Sheet in Google Drive.
Use this tool when the user explicitly asks to create a new document Use this tool when the user explicitly asks to create a new document
or spreadsheet in Google Drive. or spreadsheet in Google Drive. The user MUST specify a topic before
you call this tool. If the request does not contain a topic (e.g.
"create a drive doc" or "make a Google Sheet"), ask what the file
should be about. Never call this tool without a clear topic from the user.
Args: Args:
name: The file name (without extension). name: The file name (without extension).
file_type: Either "google_doc" or "google_sheet". file_type: Either "google_doc" or "google_sheet".
content: Optional initial content. For google_doc, provide markdown text. content: Optional initial content. Generate from the user's topic.
For google_sheet, provide CSV-formatted text. For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.
Returns: Returns:
Dictionary with: Dictionary with:
@ -55,8 +58,8 @@ def create_create_google_drive_file_tool(
Inform the user they need to re-authenticate and do NOT retry the action. Inform the user they need to re-authenticate and do NOT retry the action.
Examples: Examples:
- "Create a Google Doc called 'Meeting Notes'" - "Create a Google Doc with today's meeting notes"
- "Create a spreadsheet named 'Budget 2026' with some sample data" - "Create a spreadsheet for the 2026 budget"
""" """
logger.info( logger.info(
f"create_google_drive_file called: name='{name}', type='{file_type}'" f"create_google_drive_file called: name='{name}', type='{file_type}'"
@ -84,6 +87,15 @@ def create_create_google_drive_file_tool(
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Google Drive accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
logger.info( logger.info(
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'" f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
) )
@ -154,14 +166,18 @@ def create_create_google_drive_file_tool(
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
if final_connector_id is not None: if final_connector_id is not None:
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type.in_(_drive_types),
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
) )
) )
connector = result.scalars().first() connector = result.scalars().first()
@ -176,8 +192,7 @@ def create_create_google_drive_file_tool(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type.in_(_drive_types),
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
) )
) )
connector = result.scalars().first() connector = result.scalars().first()
@ -191,8 +206,22 @@ def create_create_google_drive_file_tool(
logger.info( logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
) )
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient( client = GoogleDriveClient(
session=db_session, connector_id=actual_connector_id session=db_session,
connector_id=actual_connector_id,
credentials=pre_built_creds,
) )
try: try:
created = await client.create_file( created = await client.create_file(
@ -206,22 +235,65 @@ def create_create_google_drive_file_tool(
logger.warning( logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}" f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
) )
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "insufficient_permissions",
"connector_id": actual_connector_id, "connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate.", "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise raise
logger.info( logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
) )
kb_message_suffix = ""
try:
from app.services.google_drive import GoogleDriveKBSyncService
kb_service = GoogleDriveKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
file_id=created.get("id"),
file_name=created.get("name", final_name),
mime_type=mime_type,
web_view_link=created.get("webViewLink"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
return { return {
"status": "success", "status": "success",
"file_id": created.get("id"), "file_id": created.get("id"),
"name": created.get("name"), "name": created.get("name"),
"web_view_link": created.get("webViewLink"), "web_view_link": created.get("webViewLink"),
"message": f"Successfully created '{created.get('name')}' in Google Drive.", "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
} }
except Exception as e: except Exception as e:

View file

@ -47,7 +47,6 @@ def create_delete_google_drive_file_tool(
to verify the file name or check if it has been indexed. to verify the file name or check if it has been indexed.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope. - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry this tool. Inform the user they need to re-authenticate and do NOT retry this tool.
Examples: Examples:
- "Delete the 'Meeting Notes' file from Google Drive" - "Delete the 'Meeting Notes' file from Google Drive"
- "Trash the 'Old Budget' spreadsheet" - "Trash the 'Old Budget' spreadsheet"
@ -76,6 +75,18 @@ def create_delete_google_drive_file_tool(
logger.error(f"Failed to fetch trash context: {error_msg}") logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg} return {"status": "error", "message": error_msg}
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Google Drive account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_drive",
}
file = context["file"] file = context["file"]
file_id = file["file_id"] file_id = file["file_id"]
document_id = file.get("document_id") document_id = file.get("document_id")
@ -151,13 +162,17 @@ def create_delete_google_drive_file_tool(
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType
_drive_types = [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
result = await db_session.execute( result = await db_session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id, SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type.in_(_drive_types),
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
) )
) )
connector = result.scalars().first() connector = result.scalars().first()
@ -170,7 +185,23 @@ def create_delete_google_drive_file_tool(
logger.info( logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
) )
client = GoogleDriveClient(session=db_session, connector_id=connector.id)
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
credentials=pre_built_creds,
)
try: try:
await client.trash_file(file_id=final_file_id) await client.trash_file(file_id=final_file_id)
except HttpError as http_err: except HttpError as http_err:
@ -178,10 +209,26 @@ def create_delete_google_drive_file_tool(
logger.warning( logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}" f"Insufficient permissions for connector {connector.id}: {http_err}"
) )
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "insufficient_permissions",
"connector_id": connector.id, "connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate.", "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise raise

View file

@ -0,0 +1,11 @@
"""Jira tools for creating, updating, and deleting issues."""
from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool
__all__ = [
"create_create_jira_issue_tool",
"create_delete_jira_issue_tool",
"create_update_jira_issue_tool",
]

View file

@ -0,0 +1,242 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_create_jira_issue_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def create_jira_issue(
project_key: str,
summary: str,
issue_type: str = "Task",
description: str | None = None,
priority: str | None = None,
) -> dict[str, Any]:
"""Create a new issue in Jira.
Use this tool when the user explicitly asks to create a new Jira issue/ticket.
Args:
project_key: The Jira project key (e.g. "PROJ", "ENG").
summary: Short, descriptive issue title.
issue_type: Issue type (default "Task"). Others: "Bug", "Story", "Epic".
description: Optional description body for the issue.
priority: Optional priority name (e.g. "High", "Medium", "Low").
Returns:
Dictionary with status, issue_key, and message.
IMPORTANT:
- If status is "rejected", the user declined. Do NOT retry.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
return {
"status": "auth_error",
"message": "All connected Jira accounts need re-authentication.",
"connector_type": "jira",
}
approval = interrupt(
{
"type": "jira_issue_creation",
"action": {
"tool": "create_jira_issue",
"params": {
"project_key": project_key,
"summary": summary,
"issue_type": issue_type,
"description": description,
"priority": priority,
"connector_id": connector_id,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The issue was not created.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_project_key = final_params.get("project_key", project_key)
final_summary = final_params.get("summary", summary)
final_issue_type = final_params.get("issue_type", issue_type)
final_description = final_params.get("description", description)
final_priority = final_params.get("priority", priority)
final_connector_id = final_params.get("connector_id", connector_id)
if not final_summary or not final_summary.strip():
return {"status": "error", "message": "Issue summary cannot be empty."}
if not final_project_key:
return {"status": "error", "message": "A project must be selected."}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {"status": "error", "message": "No Jira connector found."}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=actual_connector_id
)
jira_client = await jira_history._get_jira_client()
api_result = await asyncio.to_thread(
jira_client.create_issue,
project_key=final_project_key,
summary=final_summary,
issue_type=final_issue_type,
description=final_description,
priority=final_priority,
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
_conn = connector
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_key = api_result.get("key", "")
issue_url = (
f"{jira_history._base_url}/browse/{issue_key}"
if jira_history._base_url and issue_key
else ""
)
kb_message_suffix = ""
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=issue_key,
issue_identifier=issue_key,
issue_title=final_summary,
description=final_description,
state="To Do",
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_key": issue_key,
"issue_url": issue_url,
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating Jira issue: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the issue.",
}
return create_jira_issue

View file

@ -0,0 +1,209 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_jira_issue_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def delete_jira_issue(
issue_title_or_key: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Delete a Jira issue.
Use this tool when the user asks to delete or remove a Jira issue.
Args:
issue_title_or_key: The issue key (e.g. "PROJ-42") or title.
delete_from_kb: Whether to also remove from the knowledge base.
Returns:
Dictionary with status, message, and deleted_from_kb.
IMPORTANT:
- If status is "rejected", do NOT retry.
- If status is "not_found", relay the message to the user.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_deletion_context(
search_space_id, user_id, issue_title_or_key
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data["document_id"]
connector_id_from_context = context.get("account", {}).get("id")
approval = interrupt(
{
"type": "jira_issue_deletion",
"action": {
"tool": "delete_jira_issue",
"params": {
"issue_key": issue_key,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The issue was not deleted.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_issue_key = final_params.get("issue_key", issue_key)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
message = f"Jira issue {final_issue_key} deleted successfully."
if deleted_from_kb:
message += " Also removed from the knowledge base."
return {
"status": "success",
"issue_key": final_issue_key,
"deleted_from_kb": deleted_from_kb,
"message": message,
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error deleting Jira issue: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while deleting the issue.",
}
return delete_jira_issue

View file

@ -0,0 +1,252 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_update_jira_issue_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
@tool
async def update_jira_issue(
issue_title_or_key: str,
new_summary: str | None = None,
new_description: str | None = None,
new_priority: str | None = None,
) -> dict[str, Any]:
"""Update an existing Jira issue.
Use this tool when the user asks to modify, edit, or update a Jira issue.
Args:
issue_title_or_key: The issue key (e.g. "PROJ-42") or title to identify the issue.
new_summary: Optional new title/summary for the issue.
new_description: Optional new description.
new_priority: Optional new priority name.
Returns:
Dictionary with status and message.
IMPORTANT:
- If status is "rejected", do NOT retry.
- If status is "not_found", relay the message and ask user to verify.
- If status is "insufficient_permissions", inform user to re-authenticate.
"""
logger.info(
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {"status": "error", "message": "Jira tool not properly configured."}
try:
metadata_service = JiraToolMetadataService(db_session)
context = await metadata_service.get_update_context(
search_space_id, user_id, issue_title_or_key
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "jira",
}
if "not found" in error_msg.lower():
return {"status": "not_found", "message": error_msg}
return {"status": "error", "message": error_msg}
issue_data = context["issue"]
issue_key = issue_data["issue_id"]
document_id = issue_data.get("document_id")
connector_id_from_context = context.get("account", {}).get("id")
approval = interrupt(
{
"type": "jira_issue_update",
"action": {
"tool": "update_jira_issue",
"params": {
"issue_key": issue_key,
"document_id": document_id,
"new_summary": new_summary,
"new_description": new_description,
"new_priority": new_priority,
"connector_id": connector_id_from_context,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The issue was not updated.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_issue_key = final_params.get("issue_key", issue_key)
final_summary = final_params.get("new_summary", new_summary)
final_description = final_params.get("new_description", new_description)
final_priority = final_params.get("new_priority", new_priority)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_document_id = final_params.get("document_id", document_id)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this issue.",
}
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Jira connector is invalid.",
}
fields: dict[str, Any] = {}
if final_summary:
fields["summary"] = final_summary
if final_description is not None:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [{"type": "text", "text": final_description}],
}
],
}
if final_priority:
fields["priority"] = {"name": final_priority}
if not fields:
return {"status": "error", "message": "No changes specified."}
try:
jira_history = JiraHistoryConnector(
session=db_session, connector_id=final_connector_id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(
jira_client.update_issue, final_issue_key, fields
)
except Exception as api_err:
if "status code 403" in str(api_err).lower():
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
pass
return {
"status": "insufficient_permissions",
"connector_id": final_connector_id,
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
issue_url = (
f"{jira_history._base_url}/browse/{final_issue_key}"
if jira_history._base_url and final_issue_key
else ""
)
kb_message_suffix = ""
if final_document_id:
try:
from app.services.jira import JiraKBSyncService
kb_service = JiraKBSyncService(db_session)
kb_result = await kb_service.sync_after_update(
document_id=final_document_id,
issue_id=final_issue_key,
user_id=user_id,
search_space_id=search_space_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
except Exception as kb_err:
logger.warning(f"KB sync after update failed: {kb_err}")
kb_message_suffix = (
" The knowledge base will be updated in the next sync."
)
return {
"status": "success",
"issue_key": final_issue_key,
"issue_url": issue_url,
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error updating Jira issue: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while updating the issue.",
}
return update_jira_issue

View file

@ -9,6 +9,7 @@ This module provides:
""" """
import asyncio import asyncio
import contextlib
import json import json
import re import re
import time import time
@ -19,7 +20,7 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import shielded_async_session from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
@ -60,7 +61,7 @@ def _is_degenerate_query(query: str) -> bool:
async def _browse_recent_documents( async def _browse_recent_documents(
search_space_id: int, search_space_id: int,
document_type: str | None, document_type: str | list[str] | None,
top_k: int, top_k: int,
start_date: datetime | None, start_date: datetime | None,
end_date: datetime | None, end_date: datetime | None,
@ -83,14 +84,22 @@ async def _browse_recent_documents(
base_conditions = [Document.search_space_id == search_space_id] base_conditions = [Document.search_space_id == search_space_id]
if document_type is not None: if document_type is not None:
if isinstance(document_type, str): type_list = (
try: document_type if isinstance(document_type, list) else [document_type]
doc_type_enum = DocumentType[document_type] )
base_conditions.append(Document.document_type == doc_type_enum) doc_type_enums = []
except KeyError: for dt in type_list:
return [] if isinstance(dt, str):
with contextlib.suppress(KeyError):
doc_type_enums.append(DocumentType[dt])
else:
doc_type_enums.append(dt)
if not doc_type_enums:
return []
if len(doc_type_enums) == 1:
base_conditions.append(Document.document_type == doc_type_enums[0])
else: else:
base_conditions.append(Document.document_type == document_type) base_conditions.append(Document.document_type.in_(doc_type_enums))
if start_date is not None: if start_date is not None:
base_conditions.append(Document.updated_at >= start_date) base_conditions.append(Document.updated_at >= start_date)
@ -195,10 +204,6 @@ _ALL_CONNECTORS: list[str] = [
"CRAWLED_URL", "CRAWLED_URL",
"CIRCLEBACK", "CIRCLEBACK",
"OBSIDIAN_CONNECTOR", "OBSIDIAN_CONNECTOR",
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
"COMPOSIO_GMAIL_CONNECTOR",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
] ]
# Human-readable descriptions for each connector type # Human-readable descriptions for each connector type
@ -228,10 +233,6 @@ CONNECTOR_DESCRIPTIONS: dict[str, str] = {
"BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)", "BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
"CIRCLEBACK": "Circleback meeting notes, transcripts, and action items", "CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
"OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)", "OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "Google Drive files via Composio (personal cloud storage)",
"COMPOSIO_GMAIL_CONNECTOR": "Gmail emails via Composio (personal emails)",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events via Composio (personal calendar)",
} }
@ -352,6 +353,20 @@ def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS)) return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
{
"message_id",
"thread_id",
"event_id",
"calendar_id",
"google_drive_file_id",
"page_id",
"issue_id",
"connector_id",
}
)
def format_documents_for_context( def format_documents_for_context(
documents: list[dict[str, Any]], documents: list[dict[str, Any]],
*, *,
@ -480,7 +495,10 @@ def format_documents_for_context(
total_docs = len(grouped) total_docs = len(grouped)
for doc_idx, g in enumerate(grouped.values()): for doc_idx, g in enumerate(grouped.values()):
metadata_json = json.dumps(g["metadata"], ensure_ascii=False) metadata_clean = {
k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
}
metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
is_live_search = g["document_type"] in live_search_connectors is_live_search = g["document_type"] in live_search_connectors
doc_lines: list[str] = [ doc_lines: list[str] = [
@ -617,7 +635,12 @@ async def search_knowledge_base_async(
if available_document_types: if available_document_types:
doc_types_set = set(available_document_types) doc_types_set = set(available_document_types)
before_count = len(connectors) before_count = len(connectors)
connectors = [c for c in connectors if c in doc_types_set] connectors = [
c
for c in connectors
if c in doc_types_set
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
]
skipped = before_count - len(connectors) skipped = before_count - len(connectors)
if skipped: if skipped:
perf.info( perf.info(
@ -654,6 +677,13 @@ async def search_knowledge_base_async(
) )
browse_connectors = connectors if connectors else [None] # type: ignore[list-item] browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
expanded_browse = []
for c in browse_connectors:
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
else:
expanded_browse.append(c)
browse_results = await asyncio.gather( browse_results = await asyncio.gather(
*[ *[
_browse_recent_documents( _browse_recent_documents(
@ -663,7 +693,7 @@ async def search_knowledge_base_async(
start_date=resolved_start_date, start_date=resolved_start_date,
end_date=resolved_end_date, end_date=resolved_end_date,
) )
for c in browse_connectors for c in expanded_browse
] ]
) )
for docs in browse_results: for docs in browse_results:
@ -779,6 +809,10 @@ async def search_knowledge_base_async(
deduplicated.append(doc) deduplicated.append(doc)
# Sort by RRF score so the most relevant documents from ANY connector
# appear first, preventing budget truncation from hiding top results.
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
output_budget = _compute_tool_output_budget(max_input_tokens) output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(deduplicated, max_chars=output_budget) result = format_documents_for_context(deduplicated, max_chars=output_budget)

View file

@ -38,11 +38,13 @@ def create_create_linear_issue_tool(
"""Create a new issue in Linear. """Create a new issue in Linear.
Use this tool when the user explicitly asks to create, add, or file Use this tool when the user explicitly asks to create, add, or file
a new issue / ticket / task in Linear. a new issue / ticket / task in Linear. The user MUST describe the issue
before you call this tool. If the request is vague, ask what the issue
should be about. Never call this tool without a clear topic from the user.
Args: Args:
title: Short, descriptive issue title. title: Short, descriptive issue title. Infer from the user's request.
description: Optional markdown body for the issue. description: Optional markdown body for the issue. Generate from context.
Returns: Returns:
Dictionary with: Dictionary with:
@ -57,9 +59,9 @@ def create_create_linear_issue_tool(
and move on. Do NOT retry, troubleshoot, or suggest alternatives. and move on. Do NOT retry, troubleshoot, or suggest alternatives.
Examples: Examples:
- "Create a Linear issue titled 'Fix login bug'" - "Create a Linear issue for the login bug"
- "Add a ticket for the payment timeout problem" - "File a ticket about the payment timeout problem"
- "File an issue about the broken search feature" - "Add an issue for the broken search feature"
""" """
logger.info(f"create_linear_issue called: title='{title}'") logger.info(f"create_linear_issue called: title='{title}'")
@ -82,6 +84,15 @@ def create_create_linear_issue_tool(
logger.error(f"Failed to fetch creation context: {context['error']}") logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]} return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'") logger.info(f"Requesting approval for creating Linear issue: '{title}'")
approval = interrupt( approval = interrupt(
{ {
@ -215,12 +226,36 @@ def create_create_linear_issue_tool(
logger.info( logger.info(
f"Linear issue created: {result.get('identifier')} - {result.get('title')}" f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
) )
kb_message_suffix = ""
try:
from app.services.linear import LinearKBSyncService
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=result.get("id"),
issue_identifier=result.get("identifier", ""),
issue_title=result.get("title", final_title),
issue_url=result.get("url"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return { return {
"status": "success", "status": "success",
"issue_id": result.get("id"), "issue_id": result.get("id"),
"identifier": result.get("identifier"), "identifier": result.get("identifier"),
"url": result.get("url"), "url": result.get("url"),
"message": result.get("message"), "message": (result.get("message", "") + kb_message_suffix),
} }
except Exception as e: except Exception as e:

View file

@ -64,7 +64,6 @@ def create_delete_linear_issue_tool(
- If status is "not_found", inform the user conversationally using the exact message - If status is "not_found", inform the user conversationally using the exact message
provided. Do NOT treat this as an error. Simply relay the message and ask the user provided. Do NOT treat this as an error. Simply relay the message and ask the user
to verify the issue title or identifier, or check if it has been indexed. to verify the issue title or identifier, or check if it has been indexed.
Examples: Examples:
- "Delete the 'Fix login bug' Linear issue" - "Delete the 'Fix login bug' Linear issue"
- "Archive ENG-42" - "Archive ENG-42"
@ -91,6 +90,14 @@ def create_delete_linear_issue_tool(
if "error" in context: if "error" in context:
error_msg = context["error"] error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower(): if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}") logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg} return {"status": "not_found", "message": error_msg}

View file

@ -103,6 +103,14 @@ def create_update_linear_issue_tool(
if "error" in context: if "error" in context:
error_msg = context["error"] error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for update context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower(): if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}") logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg} return {"status": "not_found", "message": error_msg}

View file

@ -1,465 +0,0 @@
"""
Link preview tool for the SurfSense agent.
This module provides a tool for fetching URL metadata (title, description,
Open Graph image, etc.) to display rich link previews in the chat UI.
"""
import asyncio
import hashlib
import logging
import re
from typing import Any
from urllib.parse import urlparse
import httpx
import trafilatura
from fake_useragent import UserAgent
from langchain_core.tools import tool
from playwright.sync_api import sync_playwright
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
logger = logging.getLogger(__name__)
def extract_domain(url: str) -> str:
"""Extract the domain from a URL."""
try:
parsed = urlparse(url)
domain = parsed.netloc
# Remove 'www.' prefix if present
if domain.startswith("www."):
domain = domain[4:]
return domain
except Exception:
return ""
def extract_og_content(html: str, property_name: str) -> str | None:
"""Extract Open Graph meta content from HTML."""
# Try og:property first
pattern = rf'<meta[^>]+property=["\']og:{property_name}["\'][^>]+content=["\']([^"\']+)["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before property
pattern = rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+property=["\']og:{property_name}["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_twitter_content(html: str, name: str) -> str | None:
"""Extract Twitter Card meta content from HTML."""
pattern = (
rf'<meta[^>]+name=["\']twitter:{name}["\'][^>]+content=["\']([^"\']+)["\']'
)
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before name
pattern = (
rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']twitter:{name}["\']'
)
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_meta_description(html: str) -> str | None:
"""Extract meta description from HTML."""
pattern = r'<meta[^>]+name=["\']description["\'][^>]+content=["\']([^"\']+)["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before name
pattern = r'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']description["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_title(html: str) -> str | None:
"""Extract title from HTML."""
# Try og:title first
og_title = extract_og_content(html, "title")
if og_title:
return og_title
# Try twitter:title
twitter_title = extract_twitter_content(html, "title")
if twitter_title:
return twitter_title
# Fall back to <title> tag
pattern = r"<title[^>]*>([^<]+)</title>"
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1).strip()
return None
def extract_description(html: str) -> str | None:
"""Extract description from HTML."""
# Try og:description first
og_desc = extract_og_content(html, "description")
if og_desc:
return og_desc
# Try twitter:description
twitter_desc = extract_twitter_content(html, "description")
if twitter_desc:
return twitter_desc
# Fall back to meta description
return extract_meta_description(html)
def extract_image(html: str) -> str | None:
"""Extract image URL from HTML."""
# Try og:image first
og_image = extract_og_content(html, "image")
if og_image:
return og_image
# Try twitter:image
twitter_image = extract_twitter_content(html, "image")
if twitter_image:
return twitter_image
return None
def generate_preview_id(url: str) -> str:
"""Generate a unique ID for a link preview."""
hash_val = hashlib.md5(url.encode()).hexdigest()[:12]
return f"link-preview-{hash_val}"
def _unescape_html(text: str) -> str:
"""Unescape common HTML entities."""
return (
text.replace("&amp;", "&")
.replace("&lt;", "<")
.replace("&gt;", ">")
.replace("&quot;", '"')
.replace("&#39;", "'")
.replace("&apos;", "'")
)
def _make_absolute_url(image_url: str, base_url: str) -> str:
"""Convert a relative image URL to an absolute URL."""
if image_url.startswith(("http://", "https://")):
return image_url
if image_url.startswith("//"):
return f"https:{image_url}"
if image_url.startswith("/"):
parsed = urlparse(base_url)
return f"{parsed.scheme}://{parsed.netloc}{image_url}"
return image_url
async def fetch_with_chromium(url: str) -> dict[str, Any] | None:
"""
Fetch page content using headless Chromium browser via Playwright.
Used as a fallback when simple HTTP requests are blocked (403, etc.).
Runs the sync Playwright API in a thread so it works on any event
loop, including Windows ``SelectorEventLoop``.
Args:
url: URL to fetch
Returns:
Dict with title, description, image, and raw_html, or None if failed
"""
try:
return await asyncio.to_thread(_fetch_with_chromium_sync, url)
except Exception as e:
logger.error(f"[link_preview] Chromium fallback failed for {url}: {e}")
return None
def _fetch_with_chromium_sync(url: str) -> dict[str, Any] | None:
"""Synchronous Playwright fetch executed in a worker thread."""
logger.info(f"[link_preview] Falling back to Chromium for {url}")
ua = UserAgent()
user_agent = ua.random
playwright_proxy = get_playwright_proxy()
with sync_playwright() as p:
launch_kwargs: dict = {"headless": True}
if playwright_proxy:
launch_kwargs["proxy"] = playwright_proxy
browser = p.chromium.launch(**launch_kwargs)
context = browser.new_context(user_agent=user_agent)
page = context.new_page()
try:
page.goto(url, wait_until="domcontentloaded", timeout=30000)
raw_html = page.content()
finally:
browser.close()
if not raw_html or len(raw_html.strip()) == 0:
logger.warning(f"[link_preview] Chromium returned empty content for {url}")
return None
trafilatura_metadata = trafilatura.extract_metadata(raw_html)
image = extract_image(raw_html)
result: dict[str, Any] = {
"title": None,
"description": None,
"image": image,
"raw_html": raw_html,
}
if trafilatura_metadata:
result["title"] = trafilatura_metadata.title
result["description"] = trafilatura_metadata.description
if not result["title"]:
result["title"] = extract_title(raw_html)
if not result["description"]:
result["description"] = extract_description(raw_html)
logger.info(f"[link_preview] Successfully fetched {url} via Chromium")
return result
def create_link_preview_tool():
"""
Factory function to create the link_preview tool.
Returns:
A configured tool function for fetching link previews.
"""
@tool
async def link_preview(url: str) -> dict[str, Any]:
"""
Fetch metadata for a URL to display a rich link preview.
Use this tool when the user shares a URL or asks about a specific webpage.
This tool fetches the page's Open Graph metadata (title, description, image)
to display a nice preview card in the chat.
Common triggers include:
- User shares a URL in the chat
- User asks "What's this link about?" or similar
- User says "Show me a preview of this page"
- User wants to preview an article or webpage
Args:
url: The URL to fetch metadata for. Must be a valid HTTP/HTTPS URL.
Returns:
A dictionary containing:
- id: Unique identifier for this preview
- assetId: The URL itself (for deduplication)
- kind: "link" (type of media card)
- href: The URL to open when clicked
- title: Page title
- description: Page description (if available)
- thumb: Thumbnail/preview image URL (if available)
- domain: The domain name
- error: Error message (if fetch failed)
"""
preview_id = generate_preview_id(url)
domain = extract_domain(url)
# Validate URL
if not url.startswith(("http://", "https://")):
url = f"https://{url}"
try:
# Generate a random User-Agent to avoid bot detection
ua = UserAgent()
user_agent = ua.random
# Use residential proxy if configured
proxy_url = get_residential_proxy_url()
# Use a browser-like User-Agent to fetch Open Graph metadata.
# We're only fetching publicly available metadata (title, description, thumbnail)
# that websites intentionally expose via OG tags for link preview purposes.
async with httpx.AsyncClient(
timeout=10.0,
follow_redirects=True,
proxy=proxy_url,
headers={
"User-Agent": user_agent,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.9",
"Accept-Encoding": "gzip, deflate, br",
"Cache-Control": "no-cache",
"Pragma": "no-cache",
},
) as client:
response = await client.get(url)
response.raise_for_status()
# Get content type to ensure it's HTML
content_type = response.headers.get("content-type", "")
if "text/html" not in content_type.lower():
# Not an HTML page, return basic info
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": url.split("/")[-1] or domain,
"description": f"File from {domain}",
"domain": domain,
}
html = response.text
# Extract metadata
title = extract_title(html) or domain
description = extract_description(html)
image = extract_image(html)
# Make sure image URL is absolute
if image:
image = _make_absolute_url(image, url)
# Clean up title and description (unescape HTML entities)
if title:
title = _unescape_html(title)
if description:
description = _unescape_html(description)
# Truncate long descriptions
if len(description) > 200:
description = description[:197] + "..."
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": title,
"description": description,
"thumb": image,
"domain": domain,
}
except httpx.TimeoutException:
# Timeout - try Chromium fallback
logger.warning(
f"[link_preview] Timeout for {url}, trying Chromium fallback"
)
chromium_result = await fetch_with_chromium(url)
if chromium_result:
title = chromium_result.get("title") or domain
description = chromium_result.get("description")
image = chromium_result.get("image")
# Clean up and truncate
if title:
title = _unescape_html(title)
if description:
description = _unescape_html(description)
if len(description) > 200:
description = description[:197] + "..."
# Make sure image URL is absolute
if image:
image = _make_absolute_url(image, url)
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": title,
"description": description,
"thumb": image,
"domain": domain,
}
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": domain or "Link",
"domain": domain,
"error": "Request timed out",
}
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
# For 403 (Forbidden) and similar bot-detection errors, try Chromium fallback
if status_code in (403, 401, 406, 429):
logger.warning(
f"[link_preview] HTTP {status_code} for {url}, trying Chromium fallback"
)
chromium_result = await fetch_with_chromium(url)
if chromium_result:
title = chromium_result.get("title") or domain
description = chromium_result.get("description")
image = chromium_result.get("image")
# Clean up and truncate
if title:
title = _unescape_html(title)
if description:
description = _unescape_html(description)
if len(description) > 200:
description = description[:197] + "..."
# Make sure image URL is absolute
if image:
image = _make_absolute_url(image, url)
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": title,
"description": description,
"thumb": image,
"domain": domain,
}
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": domain or "Link",
"domain": domain,
"error": f"HTTP {status_code}",
}
except Exception as e:
error_message = str(e)
logger.error(f"[link_preview] Error fetching {url}: {error_message}")
return {
"id": preview_id,
"assetId": url,
"kind": "link",
"href": url,
"title": domain or "Link",
"domain": domain,
"error": f"Failed to fetch: {error_message[:50]}",
}
return link_preview

View file

@ -33,17 +33,21 @@ def create_create_notion_page_tool(
@tool @tool
async def create_notion_page( async def create_notion_page(
title: str, title: str,
content: str, content: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create a new page in Notion with the given title and content. """Create a new page in Notion with the given title and content.
Use this tool when the user asks you to create, save, or publish Use this tool when the user asks you to create, save, or publish
something to Notion. The page will be created in the user's something to Notion. The page will be created in the user's
configured Notion workspace. configured Notion workspace. The user MUST specify a topic before you
call this tool. If the request does not contain a topic (e.g. "create a
notion page"), ask what the page should be about. Never call this tool
without a clear topic from the user.
Args: Args:
title: The title of the Notion page. title: The title of the Notion page.
content: The markdown content for the page body (supports headings, lists, paragraphs). content: Optional markdown content for the page body (supports headings, lists, paragraphs).
Generate this yourself based on the user's topic.
Returns: Returns:
Dictionary with: Dictionary with:
@ -58,8 +62,8 @@ def create_create_notion_page_tool(
and move on. Do NOT troubleshoot or suggest alternatives. and move on. Do NOT troubleshoot or suggest alternatives.
Examples: Examples:
- "Create a Notion page titled 'Meeting Notes' with content 'Discussed project timeline'" - "Create a Notion page about our Q2 roadmap"
- "Save this to Notion with title 'Research Summary'" - "Save a summary of today's discussion to Notion"
""" """
logger.info(f"create_notion_page called: title='{title}'") logger.info(f"create_notion_page called: title='{title}'")
@ -85,6 +89,15 @@ def create_create_notion_page_tool(
"message": context["error"], "message": context["error"],
} }
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning("All Notion accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "notion",
}
logger.info(f"Requesting approval for creating Notion page: '{title}'") logger.info(f"Requesting approval for creating Notion page: '{title}'")
approval = interrupt( approval = interrupt(
{ {
@ -215,6 +228,34 @@ def create_create_notion_page_tool(
logger.info( logger.info(
f"create_page result: {result.get('status')} - {result.get('message', '')}" f"create_page result: {result.get('status')} - {result.get('message', '')}"
) )
if result.get("status") == "success":
kb_message_suffix = ""
try:
from app.services.notion import NotionKBSyncService
kb_service = NotionKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
page_id=result.get("page_id"),
page_title=result.get("title", final_title),
page_url=result.get("url"),
content=final_content,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
result["message"] = result.get("message", "") + kb_message_suffix
return result return result
except Exception as e: except Exception as e:

View file

@ -95,8 +95,19 @@ def create_delete_notion_page_tool(
"message": error_msg, "message": error_msg,
} }
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id") page_id = context.get("page_id")
connector_id_from_context = context.get("account", {}).get("id") connector_id_from_context = account.get("id")
document_id = context.get("document_id") document_id = context.get("document_id")
logger.info( logger.info(
@ -262,6 +273,18 @@ def create_delete_notion_page_tool(
raise raise
logger.error(f"Error deleting Notion page: {e}", exc_info=True) logger.error(f"Error deleting Notion page: {e}", exc_info=True)
error_str = str(e).lower()
if isinstance(e, NotionAPIError) and (
"401" in error_str or "unauthorized" in error_str
):
return {
"status": "auth_error",
"message": str(e),
"connector_id": connector_id_from_context
if "connector_id_from_context" in dir()
else None,
"connector_type": "notion",
}
if isinstance(e, ValueError | NotionAPIError): if isinstance(e, ValueError | NotionAPIError):
message = str(e) message = str(e)
else: else:

View file

@ -33,16 +33,19 @@ def create_update_notion_page_tool(
@tool @tool
async def update_notion_page( async def update_notion_page(
page_title: str, page_title: str,
content: str, content: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Update an existing Notion page by appending new content. """Update an existing Notion page by appending new content.
Use this tool when the user asks you to add content to, modify, or update Use this tool when the user asks you to add content to, modify, or update
a Notion page. The new content will be appended to the existing page content. a Notion page. The new content will be appended to the existing page content.
The user MUST specify what to add before you call this tool. If the
request is vague, ask what content they want added.
Args: Args:
page_title: The title of the Notion page to update. page_title: The title of the Notion page to update.
content: The markdown content to append to the page body (supports headings, lists, paragraphs). content: Optional markdown content to append to the page body (supports headings, lists, paragraphs).
Generate this yourself based on the user's request.
Returns: Returns:
Dictionary with: Dictionary with:
@ -60,10 +63,9 @@ def create_update_notion_page_tool(
Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]" Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]"
Do NOT treat this as an error. Do NOT invent information. Simply relay the message and Do NOT treat this as an error. Do NOT invent information. Simply relay the message and
ask the user to verify the page title or check if it's been indexed. ask the user to verify the page title or check if it's been indexed.
Examples: Examples:
- "Add 'New meeting notes from today' to the 'Meeting Notes' Notion page" - "Add today's meeting notes to the 'Meeting Notes' Notion page"
- "Append the following to the 'Project Plan' Notion page: '# Status Update\n\nCompleted phase 1'" - "Update the 'Project Plan' page with a status update on phase 1"
""" """
logger.info( logger.info(
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}" f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
@ -107,6 +109,17 @@ def create_update_notion_page_tool(
"message": error_msg, "message": error_msg,
} }
account = context.get("account", {})
if account.get("auth_expired"):
logger.warning(
"Notion account %s has expired authentication",
account.get("id"),
)
return {
"status": "auth_error",
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
}
page_id = context.get("page_id") page_id = context.get("page_id")
document_id = context.get("document_id") document_id = context.get("document_id")
connector_id_from_context = context.get("account", {}).get("id") connector_id_from_context = context.get("account", {}).get("id")
@ -261,6 +274,18 @@ def create_update_notion_page_tool(
raise raise
logger.error(f"Error updating Notion page: {e}", exc_info=True) logger.error(f"Error updating Notion page: {e}", exc_info=True)
error_str = str(e).lower()
if isinstance(e, NotionAPIError) and (
"401" in error_str or "unauthorized" in error_str
):
return {
"status": "auth_error",
"message": str(e),
"connector_id": connector_id_from_context
if "connector_id_from_context" in dir()
else None,
"connector_type": "notion",
}
if isinstance(e, ValueError | NotionAPIError): if isinstance(e, ValueError | NotionAPIError):
message = str(e) message = str(e)
else: else:

View file

@ -45,19 +45,38 @@ from langchain_core.tools import BaseTool
from app.db import ChatVisibility from app.db import ChatVisibility
from .display_image import create_display_image_tool from .confluence import (
create_create_confluence_page_tool,
create_delete_confluence_page_tool,
create_update_confluence_page_tool,
)
from .generate_image import create_generate_image_tool from .generate_image import create_generate_image_tool
from .gmail import (
create_create_gmail_draft_tool,
create_send_gmail_email_tool,
create_trash_gmail_email_tool,
create_update_gmail_draft_tool,
)
from .google_calendar import (
create_create_calendar_event_tool,
create_delete_calendar_event_tool,
create_update_calendar_event_tool,
)
from .google_drive import ( from .google_drive import (
create_create_google_drive_file_tool, create_create_google_drive_file_tool,
create_delete_google_drive_file_tool, create_delete_google_drive_file_tool,
) )
from .jira import (
create_create_jira_issue_tool,
create_delete_jira_issue_tool,
create_update_jira_issue_tool,
)
from .knowledge_base import create_search_knowledge_base_tool from .knowledge_base import create_search_knowledge_base_tool
from .linear import ( from .linear import (
create_create_linear_issue_tool, create_create_linear_issue_tool,
create_delete_linear_issue_tool, create_delete_linear_issue_tool,
create_update_linear_issue_tool, create_update_linear_issue_tool,
) )
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools from .mcp_tool import load_mcp_tools
from .notion import ( from .notion import (
create_create_notion_page_tool, create_create_notion_page_tool,
@ -166,20 +185,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
# are optional — when missing, source_strategy="kb_search" degrades # are optional — when missing, source_strategy="kb_search" degrades
# gracefully to "provided" # gracefully to "provided"
), ),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(
name="link_preview",
description="Fetch metadata for a URL to display a rich preview card",
factory=lambda deps: create_link_preview_tool(),
requires=[],
),
# Display image tool - shows images in the chat
ToolDefinition(
name="display_image",
description="Display an image in the chat with metadata",
factory=lambda deps: create_display_image_tool(),
requires=[],
),
# Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.) # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
ToolDefinition( ToolDefinition(
name="generate_image", name="generate_image",
@ -257,7 +262,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["user_id", "search_space_id", "db_session", "thread_visibility"], requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
), ),
# ========================================================================= # =========================================================================
# LINEAR TOOLS - create, update, delete issues (WIP - hidden from UI) # LINEAR TOOLS - create, update, delete issues
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
# ========================================================================= # =========================================================================
ToolDefinition( ToolDefinition(
name="create_linear_issue", name="create_linear_issue",
@ -268,8 +274,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
ToolDefinition( ToolDefinition(
name="update_linear_issue", name="update_linear_issue",
@ -280,8 +284,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
ToolDefinition( ToolDefinition(
name="delete_linear_issue", name="delete_linear_issue",
@ -292,11 +294,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
# ========================================================================= # =========================================================================
# NOTION TOOLS - create, update, delete pages (WIP - hidden from UI) # NOTION TOOLS - create, update, delete pages
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
# ========================================================================= # =========================================================================
ToolDefinition( ToolDefinition(
name="create_notion_page", name="create_notion_page",
@ -307,8 +308,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
ToolDefinition( ToolDefinition(
name="update_notion_page", name="update_notion_page",
@ -319,8 +318,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
ToolDefinition( ToolDefinition(
name="delete_notion_page", name="delete_notion_page",
@ -331,11 +328,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
# ========================================================================= # =========================================================================
# GOOGLE DRIVE TOOLS - create files, delete files (WIP - hidden from UI) # GOOGLE DRIVE TOOLS - create files, delete files
# Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py)
# ========================================================================= # =========================================================================
ToolDefinition( ToolDefinition(
name="create_google_drive_file", name="create_google_drive_file",
@ -346,8 +342,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False,
hidden=True,
), ),
ToolDefinition( ToolDefinition(
name="delete_google_drive_file", name="delete_google_drive_file",
@ -358,8 +352,152 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"], user_id=deps["user_id"],
), ),
requires=["db_session", "search_space_id", "user_id"], requires=["db_session", "search_space_id", "user_id"],
enabled_by_default=False, ),
hidden=True, # =========================================================================
# GOOGLE CALENDAR TOOLS - create, update, delete events
# Auto-disabled when no Google Calendar connector is configured
# =========================================================================
ToolDefinition(
name="create_calendar_event",
description="Create a new event on Google Calendar",
factory=lambda deps: create_create_calendar_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_calendar_event",
description="Update an existing indexed Google Calendar event",
factory=lambda deps: create_update_calendar_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_calendar_event",
description="Delete an existing indexed Google Calendar event",
factory=lambda deps: create_delete_calendar_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
# Auto-disabled when no Gmail connector is configured
# =========================================================================
ToolDefinition(
name="create_gmail_draft",
description="Create a draft email in Gmail",
factory=lambda deps: create_create_gmail_draft_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="send_gmail_email",
description="Send an email via Gmail",
factory=lambda deps: create_send_gmail_email_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="trash_gmail_email",
description="Move an indexed email to trash in Gmail",
factory=lambda deps: create_trash_gmail_email_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_gmail_draft",
description="Update an existing Gmail draft",
factory=lambda deps: create_update_gmail_draft_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# JIRA TOOLS - create, update, delete issues
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_jira_issue",
description="Create a new issue in the user's Jira project",
factory=lambda deps: create_create_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_jira_issue",
description="Update an existing indexed Jira issue",
factory=lambda deps: create_update_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_jira_issue",
description="Delete an existing indexed Jira issue",
factory=lambda deps: create_delete_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# CONFLUENCE TOOLS - create, update, delete pages
# Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_confluence_page",
description="Create a new page in the user's Confluence space",
factory=lambda deps: create_create_confluence_page_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_confluence_page",
description="Update an existing indexed Confluence page",
factory=lambda deps: create_update_confluence_page_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_confluence_page",
description="Delete an existing indexed Confluence page",
factory=lambda deps: create_delete_confluence_page_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
), ),
] ]
@ -413,7 +551,7 @@ def build_tools(
tools = build_tools(deps) tools = build_tools(deps)
# Use only specific tools # Use only specific tools
tools = build_tools(deps, enabled_tools=["search_knowledge_base", "link_preview"]) tools = build_tools(deps, enabled_tools=["search_knowledge_base"])
# Use defaults but disable podcast # Use defaults but disable podcast
tools = build_tools(deps, disabled_tools=["generate_podcast"]) tools = build_tools(deps, disabled_tools=["generate_podcast"])

View file

@ -1,719 +0,0 @@
"""
Composio Gmail Connector Module.
Provides Gmail specific methods for data retrieval and indexing via Composio.
"""
import logging
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import Any
from bs4 import BeautifulSoup
from markdownify import markdownify as md
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
calculate_date_range,
check_duplicate_document_by_hash,
safe_set_chunks,
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
# Heartbeat configuration
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__)
def get_current_timestamp() -> datetime:
"""Get the current timestamp with timezone for updated_at field."""
return datetime.now(UTC)
async def check_document_by_unique_identifier(
session: AsyncSession, unique_identifier_hash: str
) -> Document | None:
"""Check if a document with the given unique identifier hash already exists."""
existing_doc_result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
.where(Document.unique_identifier_hash == unique_identifier_hash)
)
return existing_doc_result.scalars().first()
async def update_connector_last_indexed(
session: AsyncSession,
connector,
update_last_indexed: bool = True,
) -> None:
"""Update the last_indexed_at timestamp for a connector."""
if update_last_indexed:
connector.last_indexed_at = datetime.now(UTC)
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
class ComposioGmailConnector(ComposioConnector):
"""
Gmail specific Composio connector.
Provides methods for listing messages, getting message details, and formatting
Gmail messages from Gmail via Composio.
"""
async def list_gmail_messages(
self,
query: str = "",
max_results: int = 50,
page_token: str | None = None,
) -> tuple[list[dict[str, Any]], str | None, int | None, str | None]:
"""
List Gmail messages via Composio with pagination support.
Args:
query: Gmail search query.
max_results: Maximum number of messages per page (default: 50).
page_token: Optional pagination token for next page.
Returns:
Tuple of (messages list, next_page_token, result_size_estimate, error message).
"""
connected_account_id = await self.get_connected_account_id()
if not connected_account_id:
return [], None, None, "No connected account ID found"
entity_id = await self.get_entity_id()
service = await self._get_service()
return await service.get_gmail_messages(
connected_account_id=connected_account_id,
entity_id=entity_id,
query=query,
max_results=max_results,
page_token=page_token,
)
async def get_gmail_message_detail(
self, message_id: str
) -> tuple[dict[str, Any] | None, str | None]:
"""
Get full details of a Gmail message via Composio.
Args:
message_id: Gmail message ID.
Returns:
Tuple of (message details, error message).
"""
connected_account_id = await self.get_connected_account_id()
if not connected_account_id:
return None, "No connected account ID found"
entity_id = await self.get_entity_id()
service = await self._get_service()
return await service.get_gmail_message_detail(
connected_account_id=connected_account_id,
entity_id=entity_id,
message_id=message_id,
)
@staticmethod
def _html_to_markdown(html: str) -> str:
"""Convert HTML (especially email layouts with nested tables) to clean markdown."""
soup = BeautifulSoup(html, "html.parser")
for tag in soup.find_all(["style", "script", "img"]):
tag.decompose()
for tag in soup.find_all(
["table", "thead", "tbody", "tfoot", "tr", "td", "th"]
):
tag.unwrap()
return md(str(soup)).strip()
def format_gmail_message_to_markdown(self, message: dict[str, Any]) -> str:
"""
Format a Gmail message to markdown.
Args:
message: Message object from Composio's GMAIL_FETCH_EMAILS response.
Composio structure: messageId, messageText, messageTimestamp,
payload.headers, labelIds, attachmentList
Returns:
Formatted markdown string.
"""
try:
# Composio uses 'messageId' (camelCase)
message_id = message.get("messageId", "") or message.get("id", "")
label_ids = message.get("labelIds", [])
# Extract headers from payload
payload = message.get("payload", {})
headers = payload.get("headers", [])
# Parse headers into a dict
header_dict = {}
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
header_dict[name] = value
# Extract key information
subject = header_dict.get("subject", "No Subject")
from_email = header_dict.get("from", "Unknown Sender")
to_email = header_dict.get("to", "Unknown Recipient")
# Composio provides messageTimestamp directly
date_str = message.get("messageTimestamp", "") or header_dict.get(
"date", "Unknown Date"
)
# Build markdown content
markdown_content = f"# {subject}\n\n"
markdown_content += f"**From:** {from_email}\n"
markdown_content += f"**To:** {to_email}\n"
markdown_content += f"**Date:** {date_str}\n"
if label_ids:
markdown_content += f"**Labels:** {', '.join(label_ids)}\n"
markdown_content += "\n---\n\n"
# Composio provides full message text in 'messageText' which is often raw HTML
message_text = message.get("messageText", "")
if message_text:
message_text = self._html_to_markdown(message_text)
markdown_content += f"## Content\n\n{message_text}\n\n"
else:
# Fallback to snippet if no messageText
snippet = message.get("snippet", "")
if snippet:
markdown_content += f"## Preview\n\n{snippet}\n\n"
# Add attachment info if present
attachments = message.get("attachmentList", [])
if attachments:
markdown_content += "## Attachments\n\n"
for att in attachments:
att_name = att.get("filename", att.get("name", "Unknown"))
markdown_content += f"- {att_name}\n"
markdown_content += "\n"
# Add message metadata
markdown_content += "## Message Details\n\n"
markdown_content += f"- **Message ID:** {message_id}\n"
return markdown_content
except Exception as e:
return f"Error formatting message to markdown: {e!s}"
# ============ Indexer Functions ============
async def _analyze_gmail_messages_phase1(
session: AsyncSession,
messages: list[dict[str, Any]],
composio_connector: ComposioGmailConnector,
connector_id: int,
search_space_id: int,
user_id: str,
) -> tuple[list[dict[str, Any]], int, int]:
"""
Phase 1: Analyze all messages, create pending documents.
Makes ALL documents visible in the UI immediately with pending status.
Returns:
Tuple of (messages_to_process, documents_skipped, duplicate_content_count)
"""
messages_to_process = []
documents_skipped = 0
duplicate_content_count = 0
for message in messages:
try:
# Composio uses 'messageId' (camelCase), not 'id'
message_id = message.get("messageId", "") or message.get("id", "")
if not message_id:
documents_skipped += 1
continue
# Extract message info from Composio response
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = message.get("messageTimestamp", "Unknown Date")
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
# Format to markdown using the full message data
markdown_content = composio_connector.format_gmail_message_to_markdown(
message
)
# Check for empty content
if not markdown_content.strip():
logger.warning(f"Skipping Gmail message with no content: {subject}")
documents_skipped += 1
continue
# Generate unique identifier
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"])
unique_identifier_hash = generate_unique_identifier_hash(
document_type, f"gmail_{message_id}", search_space_id
)
content_hash = generate_content_hash(markdown_content, search_space_id)
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Get label IDs and thread_id from Composio response
label_ids = message.get("labelIds", [])
thread_id = message.get("threadId", "") or message.get("thread_id", "")
if existing_document:
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
"label_ids": label_ids,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from standard connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Message {subject} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=subject,
document_type=DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"]),
document_metadata={
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"labels": label_ids,
"connector_id": connector_id,
"toolkit_id": "gmail",
"source": "composio",
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
messages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
"label_ids": label_ids,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
documents_skipped += 1
continue
return messages_to_process, documents_skipped, duplicate_content_count
async def _process_gmail_messages_phase2(
session: AsyncSession,
messages_to_process: list[dict[str, Any]],
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
"""
Phase 2: Process each document one by one.
Each document transitions: pending processing ready/failed
Returns:
Tuple of (documents_indexed, documents_failed)
"""
documents_indexed = 0
documents_failed = 0
last_heartbeat_time = time.time()
for item in messages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"document_type": "Gmail Message (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
item["markdown_content"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["subject"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"labels": item["label_ids"],
"connector_id": connector_id,
"source": "composio",
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Gmail messages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
return documents_indexed, documents_failed
async def index_composio_gmail(
session: AsyncSession,
connector,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str | None,
end_date: str | None,
task_logger: TaskLoggingService,
log_entry,
update_last_indexed: bool = True,
max_items: int = 1000,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str]:
"""Index Gmail messages via Composio with real-time document status updates."""
try:
composio_connector = ComposioGmailConnector(session, connector_id)
# Normalize date values - handle "undefined" strings from frontend
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
if start_date is not None and end_date is not None:
start_date_str = start_date
end_date_str = end_date
else:
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Build query with date range
query_parts = []
if start_date_str:
query_parts.append(f"after:{start_date_str.replace('-', '/')}")
if end_date_str:
query_parts.append(f"before:{end_date_str.replace('-', '/')}")
query = " ".join(query_parts) if query_parts else ""
logger.info(
f"Gmail query for connector {connector_id}: '{query}' "
f"(start_date={start_date_str}, end_date={end_date_str})"
)
await task_logger.log_task_progress(
log_entry,
f"Fetching Gmail messages via Composio for connector {connector_id}",
{"stage": "fetching_messages"},
)
# =======================================================================
# FETCH ALL MESSAGES FIRST
# =======================================================================
batch_size = 50
page_token = None
all_messages = []
result_size_estimate = None
last_heartbeat_time = time.time()
while len(all_messages) < max_items:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(len(all_messages))
last_heartbeat_time = current_time
remaining = max_items - len(all_messages)
current_batch_size = min(batch_size, remaining)
(
messages,
next_token,
result_size_estimate_batch,
error,
) = await composio_connector.list_gmail_messages(
query=query,
max_results=current_batch_size,
page_token=page_token,
)
if error:
await task_logger.log_task_failure(
log_entry, f"Failed to fetch Gmail messages: {error}", {}
)
return 0, f"Failed to fetch Gmail messages: {error}"
if not messages:
break
if result_size_estimate is None and result_size_estimate_batch is not None:
result_size_estimate = result_size_estimate_batch
logger.info(
f"Gmail API estimated {result_size_estimate} total messages for query: '{query}'"
)
all_messages.extend(messages)
logger.info(
f"Fetched {len(messages)} messages (total: {len(all_messages)})"
)
if not next_token or len(messages) < current_batch_size:
break
page_token = next_token
if not all_messages:
success_msg = "No Gmail messages found in the specified date range"
await task_logger.log_task_success(
log_entry, success_msg, {"messages_count": 0}
)
await update_connector_last_indexed(session, connector, update_last_indexed)
await session.commit()
return (
0,
None,
) # Return None (not error) when no items found - this is success with 0 items
logger.info(f"Found {len(all_messages)} Gmail messages to index via Composio")
# =======================================================================
# PHASE 1: Analyze all messages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
await task_logger.log_task_progress(
log_entry,
f"Phase 1: Creating pending documents for {len(all_messages)} messages",
{"stage": "phase1_pending"},
)
(
messages_to_process,
documents_skipped,
duplicate_content_count,
) = await _analyze_gmail_messages_phase1(
session=session,
messages=all_messages,
composio_connector=composio_connector,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
# Commit all pending documents - they all appear in UI now
new_documents_count = len([m for m in messages_to_process if m["is_new"]])
if new_documents_count > 0:
logger.info(f"Phase 1: Committing {new_documents_count} pending documents")
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
await task_logger.log_task_progress(
log_entry,
f"Phase 2: Processing {len(messages_to_process)} documents",
{"stage": "phase2_processing"},
)
documents_indexed, documents_failed = await _process_gmail_messages_phase2(
session=session,
messages_to_process=messages_to_process,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)
# CRITICAL: Always update timestamp so Electric SQL syncs
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
try:
await session.commit()
logger.info(
"Successfully committed all Composio Gmail document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
await task_logger.log_task_success(
log_entry,
f"Successfully completed Gmail indexing via Composio for connector {connector_id}",
{
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"Composio Gmail indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
return documents_indexed, warning_message
except Exception as e:
logger.error(f"Failed to index Gmail via Composio: {e!s}", exc_info=True)
return 0, f"Failed to index Gmail via Composio: {e!s}"

View file

@ -1,566 +0,0 @@
"""
Composio Google Calendar Connector Module.
Provides Google Calendar specific methods for data retrieval and indexing via Composio.
"""
import logging
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
calculate_date_range,
check_duplicate_document_by_hash,
safe_set_chunks,
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
# Heartbeat configuration
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30
logger = logging.getLogger(__name__)
def get_current_timestamp() -> datetime:
"""Get the current timestamp with timezone for updated_at field."""
return datetime.now(UTC)
async def check_document_by_unique_identifier(
session: AsyncSession, unique_identifier_hash: str
) -> Document | None:
"""Check if a document with the given unique identifier hash already exists."""
existing_doc_result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
.where(Document.unique_identifier_hash == unique_identifier_hash)
)
return existing_doc_result.scalars().first()
async def update_connector_last_indexed(
session: AsyncSession,
connector,
update_last_indexed: bool = True,
) -> None:
"""Update the last_indexed_at timestamp for a connector."""
if update_last_indexed:
connector.last_indexed_at = datetime.now(UTC)
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
class ComposioGoogleCalendarConnector(ComposioConnector):
"""
Google Calendar specific Composio connector.
Provides methods for listing calendar events and formatting them from
Google Calendar via Composio.
"""
async def list_calendar_events(
self,
time_min: str | None = None,
time_max: str | None = None,
max_results: int = 250,
) -> tuple[list[dict[str, Any]], str | None]:
"""
List Google Calendar events via Composio.
Args:
time_min: Start time (RFC3339 format).
time_max: End time (RFC3339 format).
max_results: Maximum number of events.
Returns:
Tuple of (events list, error message).
"""
connected_account_id = await self.get_connected_account_id()
if not connected_account_id:
return [], "No connected account ID found"
entity_id = await self.get_entity_id()
service = await self._get_service()
return await service.get_calendar_events(
connected_account_id=connected_account_id,
entity_id=entity_id,
time_min=time_min,
time_max=time_max,
max_results=max_results,
)
def format_calendar_event_to_markdown(self, event: dict[str, Any]) -> str:
"""
Format a Google Calendar event to markdown.
Args:
event: Event object from Google Calendar API.
Returns:
Formatted markdown string.
"""
try:
# Extract basic event information
summary = event.get("summary", "No Title")
description = event.get("description", "")
location = event.get("location", "")
# Extract start and end times
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
# Format times for display
def format_time(time_str: str) -> str:
if not time_str:
return "Unknown"
try:
if "T" in time_str:
dt = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M")
return time_str
except Exception:
return time_str
start_formatted = format_time(start_time)
end_formatted = format_time(end_time)
# Extract attendees
attendees = event.get("attendees", [])
attendee_list = []
for attendee in attendees:
email = attendee.get("email", "")
display_name = attendee.get("displayName", email)
response_status = attendee.get("responseStatus", "")
attendee_list.append(f"- {display_name} ({response_status})")
# Build markdown content
markdown_content = f"# {summary}\n\n"
markdown_content += f"**Start:** {start_formatted}\n"
markdown_content += f"**End:** {end_formatted}\n"
if location:
markdown_content += f"**Location:** {location}\n"
markdown_content += "\n"
if description:
markdown_content += f"## Description\n\n{description}\n\n"
if attendee_list:
markdown_content += "## Attendees\n\n"
markdown_content += "\n".join(attendee_list)
markdown_content += "\n\n"
# Add event metadata
markdown_content += "## Event Details\n\n"
markdown_content += f"- **Event ID:** {event.get('id', 'Unknown')}\n"
markdown_content += f"- **Created:** {event.get('created', 'Unknown')}\n"
markdown_content += f"- **Updated:** {event.get('updated', 'Unknown')}\n"
return markdown_content
except Exception as e:
return f"Error formatting event to markdown: {e!s}"
# ============ Indexer Functions ============
async def index_composio_google_calendar(
session: AsyncSession,
connector,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str | None,
end_date: str | None,
task_logger: TaskLoggingService,
log_entry,
update_last_indexed: bool = True,
max_items: int = 2500,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str]:
"""Index Google Calendar events via Composio."""
try:
composio_connector = ComposioGoogleCalendarConnector(session, connector_id)
await task_logger.log_task_progress(
log_entry,
f"Fetching Google Calendar events via Composio for connector {connector_id}",
{"stage": "fetching_events"},
)
# Normalize date values - handle "undefined" strings from frontend
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
# This ensures user-selected dates are respected (matching non-Composio Calendar connector behavior)
if start_date is not None and end_date is not None:
# User provided both dates - use them directly
start_date_str = start_date
end_date_str = end_date
else:
# Calculate date range with defaults (uses last_indexed_at or 365 days back)
# This ensures indexing works even when user doesn't specify dates
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Build time range for API call
time_min = f"{start_date_str}T00:00:00Z"
time_max = f"{end_date_str}T23:59:59Z"
logger.info(
f"Google Calendar query for connector {connector_id}: "
f"(start_date={start_date_str}, end_date={end_date_str})"
)
events, error = await composio_connector.list_calendar_events(
time_min=time_min,
time_max=time_max,
max_results=max_items,
)
if error:
await task_logger.log_task_failure(
log_entry, f"Failed to fetch Calendar events: {error}", {}
)
return 0, f"Failed to fetch Calendar events: {error}"
if not events:
success_msg = "No Google Calendar events found in the specified date range"
await task_logger.log_task_success(
log_entry, success_msg, {"events_count": 0}
)
# CRITICAL: Update timestamp even when no events found so Electric SQL syncs and UI shows indexed status
await update_connector_last_indexed(session, connector, update_last_indexed)
await session.commit()
return (
0,
None,
) # Return None (not error) when no items found - this is success with 0 items
logger.info(f"Found {len(events)} Google Calendar events to index via Composio")
documents_indexed = 0
documents_skipped = 0
documents_failed = 0 # Track events that failed processing
duplicate_content_count = (
0 # Track events skipped due to duplicate content_hash
)
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all events, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
events_to_process = [] # List of dicts with document and event data
new_documents_created = False
for event in events:
try:
# Handle both standard Google API and potential Composio variations
event_id = event.get("id", "") or event.get("eventId", "")
summary = (
event.get("summary", "") or event.get("title", "") or "No Title"
)
if not event_id:
documents_skipped += 1
continue
# Format to markdown
markdown_content = composio_connector.format_calendar_event_to_markdown(
event
)
# Generate unique identifier
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"])
unique_identifier_hash = generate_unique_identifier_hash(
document_type, f"calendar_{event_id}", search_space_id
)
content_hash = generate_content_hash(markdown_content, search_space_id)
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Extract event times
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
if existing_document:
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
events_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"event_id": event_id,
"summary": summary,
"start_time": start_time,
"end_time": end_time,
"location": location,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from standard connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Event {summary} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=summary,
document_type=DocumentType(
TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"]
),
document_metadata={
"event_id": event_id,
"summary": summary,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
"toolkit_id": "googlecalendar",
"source": "composio",
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
events_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"event_id": event_id,
"summary": summary,
"start_time": start_time,
"end_time": end_time,
"location": location,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
for item in events_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"summary": item["summary"],
"start_time": item["start_time"],
"document_type": "Google Calendar Event (Composio)",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = (
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["summary"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"event_id": item["event_id"],
"summary": item["summary"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"],
"connector_id": connector_id,
"source": "composio",
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Electric SQL syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
# This matches the pattern used in non-Composio Gmail indexer
logger.info(
f"Final commit: Total {documents_indexed} Google Calendar events processed"
)
try:
await session.commit()
logger.info(
"Successfully committed all Composio Google Calendar document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same event was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
await task_logger.log_task_success(
log_entry,
f"Successfully completed Google Calendar indexing via Composio for connector {connector_id}",
{
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"Composio Google Calendar indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
return documents_indexed, warning_message
except Exception as e:
logger.error(
f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True
)
return 0, f"Failed to index Google Calendar via Composio: {e!s}"

View file

@ -14,7 +14,6 @@ from sqlalchemy.future import select
from app.config import config from app.config import config
from app.connectors.confluence_connector import ConfluenceConnector from app.connectors.confluence_connector import ConfluenceConnector
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
from app.routes.confluence_add_connector_route import refresh_confluence_token
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption from app.utils.oauth_security import TokenEncryption
@ -190,7 +189,11 @@ class ConfluenceHistoryConnector:
f"Connector {self._connector_id} not found; cannot refresh token." f"Connector {self._connector_id} not found; cannot refresh token."
) )
# Refresh token # Lazy import to avoid circular dependency
from app.routes.confluence_add_connector_route import (
refresh_confluence_token,
)
connector = await refresh_confluence_token(self._session, connector) connector = await refresh_confluence_token(self._session, connector)
# Reload credentials after refresh # Reload credentials after refresh
@ -341,6 +344,61 @@ class ConfluenceHistoryConnector:
logger.error(f"Confluence API request error: {e!s}", exc_info=True) logger.error(f"Confluence API request error: {e!s}", exc_info=True)
raise Exception(f"Confluence API request failed: {e!s}") from e raise Exception(f"Confluence API request failed: {e!s}") from e
async def _make_api_request_with_method(
self,
endpoint: str,
method: str = "GET",
json_payload: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Make a request to the Confluence API with a specified HTTP method."""
if not self._use_oauth:
raise ValueError("Write operations require OAuth authentication")
token = await self._get_valid_token()
base_url = await self._get_base_url()
http_client = await self._get_client()
url = f"{base_url}/wiki/api/v2/{endpoint}"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
"Accept": "application/json",
}
try:
method_upper = method.upper()
if method_upper == "POST":
response = await http_client.post(
url, headers=headers, json=json_payload, params=params
)
elif method_upper == "PUT":
response = await http_client.put(
url, headers=headers, json=json_payload, params=params
)
elif method_upper == "DELETE":
response = await http_client.delete(url, headers=headers, params=params)
else:
response = await http_client.get(url, headers=headers, params=params)
response.raise_for_status()
if response.status_code == 204 or not response.text:
return {"status": "success"}
return response.json()
except httpx.HTTPStatusError as e:
error_detail = {
"status_code": e.response.status_code,
"url": str(e.request.url),
"response_text": e.response.text,
}
logger.error(f"Confluence API HTTP error: {error_detail}")
raise Exception(
f"Confluence API request failed (HTTP {e.response.status_code}): {e.response.text}"
) from e
except httpx.RequestError as e:
logger.error(f"Confluence API request error: {e!s}", exc_info=True)
raise Exception(f"Confluence API request failed: {e!s}") from e
async def get_all_spaces(self) -> list[dict[str, Any]]: async def get_all_spaces(self) -> list[dict[str, Any]]:
""" """
Fetch all spaces from Confluence. Fetch all spaces from Confluence.
@ -593,6 +651,65 @@ class ConfluenceHistoryConnector:
except Exception as e: except Exception as e:
return [], f"Error fetching pages: {e!s}" return [], f"Error fetching pages: {e!s}"
async def get_page(self, page_id: str) -> dict[str, Any]:
"""Fetch a single page by ID with body content."""
return await self._make_api_request(
f"pages/{page_id}", params={"body-format": "storage"}
)
async def create_page(
self,
space_id: str,
title: str,
body: str,
parent_page_id: str | None = None,
) -> dict[str, Any]:
"""Create a new Confluence page."""
payload: dict[str, Any] = {
"spaceId": space_id,
"title": title,
"body": {
"representation": "storage",
"value": body,
},
"status": "current",
}
if parent_page_id:
payload["parentId"] = parent_page_id
return await self._make_api_request_with_method(
"pages", method="POST", json_payload=payload
)
async def update_page(
self,
page_id: str,
title: str,
body: str,
version_number: int,
) -> dict[str, Any]:
"""Update an existing Confluence page (requires version number)."""
payload: dict[str, Any] = {
"id": page_id,
"title": title,
"body": {
"representation": "storage",
"value": body,
},
"version": {
"number": version_number,
},
"status": "current",
}
return await self._make_api_request_with_method(
f"pages/{page_id}", method="PUT", json_payload=payload
)
async def delete_page(self, page_id: str) -> dict[str, Any]:
"""Delete a Confluence page."""
return await self._make_api_request_with_method(
f"pages/{page_id}", method="DELETE"
)
async def close(self): async def close(self):
"""Close the HTTP client connection.""" """Close the HTTP client connection."""
if self._http_client: if self._http_client:

View file

@ -52,44 +52,39 @@ class GoogleCalendarConnector:
) -> Credentials: ) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
Returns:
Google OAuth credentials Supports both native OAuth (with refresh_token) and Composio-sourced
Raises: credentials (with refresh_handler). For Composio credentials, validation
ValueError: If credentials have not been set and DB persistence are skipped since Composio manages its own tokens.
Exception: If credential refresh fails
""" """
if not all( has_standard_refresh = bool(self._credentials.refresh_token)
[
self._credentials.client_id, if has_standard_refresh and not all(
self._credentials.client_secret, [self._credentials.client_id, self._credentials.client_secret]
self._credentials.refresh_token,
]
): ):
raise ValueError( raise ValueError(
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set" "Google OAuth credentials (client_id, client_secret) must be set"
) )
if self._credentials and not self._credentials.expired: if self._credentials and not self._credentials.expired:
return self._credentials return self._credentials
# Create credentials from refresh token if has_standard_refresh:
self._credentials = Credentials( self._credentials = Credentials(
token=self._credentials.token, token=self._credentials.token,
refresh_token=self._credentials.refresh_token, refresh_token=self._credentials.refresh_token,
token_uri=self._credentials.token_uri, token_uri=self._credentials.token_uri,
client_id=self._credentials.client_id, client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret, client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes, scopes=self._credentials.scopes,
expiry=self._credentials.expiry, expiry=self._credentials.expiry,
) )
# Refresh the token if needed
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired or not self._credentials.valid:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB # Only persist refreshed token for native OAuth (Composio manages its own)
if self._session: if has_standard_refresh and self._session:
# Use connector_id if available, otherwise fall back to user_id query
if self._connector_id: if self._connector_id:
result = await self._session.execute( result = await self._session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
@ -110,7 +105,6 @@ class GoogleCalendarConnector:
"GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token." "GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
) )
# Encrypt sensitive credentials before storing
from app.config import config from app.config import config
from app.utils.oauth_security import TokenEncryption from app.utils.oauth_security import TokenEncryption
@ -119,7 +113,6 @@ class GoogleCalendarConnector:
if token_encrypted and config.SECRET_KEY: if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY) token_encryption = TokenEncryption(config.SECRET_KEY)
# Encrypt sensitive fields
if creds_dict.get("token"): if creds_dict.get("token"):
creds_dict["token"] = token_encryption.encrypt_token( creds_dict["token"] = token_encryption.encrypt_token(
creds_dict["token"] creds_dict["token"]
@ -143,7 +136,6 @@ class GoogleCalendarConnector:
await self._session.commit() await self._session.commit()
except Exception as e: except Exception as e:
error_str = str(e) error_str = str(e)
# Check if this is an invalid_grant error (token expired/revoked)
if ( if (
"invalid_grant" in error_str.lower() "invalid_grant" in error_str.lower()
or "token has been expired or revoked" in error_str.lower() or "token has been expired or revoked" in error_str.lower()

View file

@ -3,6 +3,7 @@
import io import io
from typing import Any from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload from googleapiclient.http import MediaIoBaseUpload
@ -15,16 +16,24 @@ from .file_types import GOOGLE_DOC, GOOGLE_SHEET
class GoogleDriveClient: class GoogleDriveClient:
"""Client for Google Drive API operations.""" """Client for Google Drive API operations."""
def __init__(self, session: AsyncSession, connector_id: int): def __init__(
self,
session: AsyncSession,
connector_id: int,
credentials: "Credentials | None" = None,
):
""" """
Initialize Google Drive client. Initialize Google Drive client.
Args: Args:
session: Database session session: Database session
connector_id: ID of the Drive connector connector_id: ID of the Drive connector
credentials: Pre-built credentials (e.g. from Composio). If None,
credentials are loaded from the DB connector config.
""" """
self.session = session self.session = session
self.connector_id = connector_id self.connector_id = connector_id
self._credentials = credentials
self.service = None self.service = None
async def get_service(self): async def get_service(self):
@ -41,7 +50,12 @@ class GoogleDriveClient:
return self.service return self.service
try: try:
credentials = await get_valid_credentials(self.session, self.connector_id) if self._credentials:
credentials = self._credentials
else:
credentials = await get_valid_credentials(
self.session, self.connector_id
)
self.service = build("drive", "v3", credentials=credentials) self.service = build("drive", "v3", credentials=credentials)
return self.service return self.service
except Exception as e: except Exception as e:

View file

@ -26,6 +26,7 @@ async def download_and_process_file(
task_logger: TaskLoggingService, task_logger: TaskLoggingService,
log_entry: Log, log_entry: Log,
connector_id: int | None = None, connector_id: int | None = None,
enable_summary: bool = True,
) -> tuple[Any, str | None, dict[str, Any] | None]: ) -> tuple[Any, str | None, dict[str, Any] | None]:
""" """
Download Google Drive file and process using Surfsense file processors. Download Google Drive file and process using Surfsense file processors.
@ -95,6 +96,7 @@ async def download_and_process_file(
}, },
} }
# Include connector_id for de-indexing support # Include connector_id for de-indexing support
connector_info["enable_summary"] = enable_summary
if connector_id is not None: if connector_id is not None:
connector_info["connector_id"] = connector_id connector_info["connector_id"] = connector_id

View file

@ -81,44 +81,39 @@ class GoogleGmailConnector:
) -> Credentials: ) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
Returns:
Google OAuth credentials Supports both native OAuth (with refresh_token) and Composio-sourced
Raises: credentials (with refresh_handler). For Composio credentials, validation
ValueError: If credentials have not been set and DB persistence are skipped since Composio manages its own tokens.
Exception: If credential refresh fails
""" """
if not all( has_standard_refresh = bool(self._credentials.refresh_token)
[
self._credentials.client_id, if has_standard_refresh and not all(
self._credentials.client_secret, [self._credentials.client_id, self._credentials.client_secret]
self._credentials.refresh_token,
]
): ):
raise ValueError( raise ValueError(
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set" "Google OAuth credentials (client_id, client_secret) must be set"
) )
if self._credentials and not self._credentials.expired: if self._credentials and not self._credentials.expired:
return self._credentials return self._credentials
# Create credentials from refresh token if has_standard_refresh:
self._credentials = Credentials( self._credentials = Credentials(
token=self._credentials.token, token=self._credentials.token,
refresh_token=self._credentials.refresh_token, refresh_token=self._credentials.refresh_token,
token_uri=self._credentials.token_uri, token_uri=self._credentials.token_uri,
client_id=self._credentials.client_id, client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret, client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes, scopes=self._credentials.scopes,
expiry=self._credentials.expiry, expiry=self._credentials.expiry,
) )
# Refresh the token if needed
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired or not self._credentials.valid:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB # Only persist refreshed token for native OAuth (Composio manages its own)
if self._session: if has_standard_refresh and self._session:
# Use connector_id if available, otherwise fall back to user_id query
if self._connector_id: if self._connector_id:
result = await self._session.execute( result = await self._session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
@ -138,12 +133,38 @@ class GoogleGmailConnector:
raise RuntimeError( raise RuntimeError(
"GMAIL connector not found; cannot persist refreshed token." "GMAIL connector not found; cannot persist refreshed token."
) )
connector.config = json.loads(self._credentials.to_json())
from app.config import config
from app.utils.oauth_security import TokenEncryption
creds_dict = json.loads(self._credentials.to_json())
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if creds_dict.get("token"):
creds_dict["token"] = token_encryption.encrypt_token(
creds_dict["token"]
)
if creds_dict.get("refresh_token"):
creds_dict["refresh_token"] = (
token_encryption.encrypt_token(
creds_dict["refresh_token"]
)
)
if creds_dict.get("client_secret"):
creds_dict["client_secret"] = (
token_encryption.encrypt_token(
creds_dict["client_secret"]
)
)
creds_dict["_token_encrypted"] = True
connector.config = creds_dict
flag_modified(connector, "config") flag_modified(connector, "config")
await self._session.commit() await self._session.commit()
except Exception as e: except Exception as e:
error_str = str(e) error_str = str(e)
# Check if this is an invalid_grant error (token expired/revoked)
if ( if (
"invalid_grant" in error_str.lower() "invalid_grant" in error_str.lower()
or "token has been expired or revoked" in error_str.lower() or "token has been expired or revoked" in error_str.lower()

View file

@ -167,14 +167,23 @@ class JiraConnector:
# Use direct base URL (works for both OAuth and legacy) # Use direct base URL (works for both OAuth and legacy)
url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}" url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"
if method.upper() == "POST": method_upper = method.upper()
if method_upper == "POST":
response = requests.post( response = requests.post(
url, headers=headers, json=json_payload, timeout=500 url, headers=headers, json=json_payload, timeout=500
) )
elif method_upper == "PUT":
response = requests.put(
url, headers=headers, json=json_payload, timeout=500
)
elif method_upper == "DELETE":
response = requests.delete(url, headers=headers, params=params, timeout=500)
else: else:
response = requests.get(url, headers=headers, params=params, timeout=500) response = requests.get(url, headers=headers, params=params, timeout=500)
if response.status_code == 200: if response.status_code in (200, 201, 204):
if response.status_code == 204 or not response.text:
return {"status": "success"}
return response.json() return response.json()
else: else:
raise Exception( raise Exception(
@ -352,6 +361,91 @@ class JiraConnector:
except Exception as e: except Exception as e:
return [], f"Error fetching issues: {e!s}" return [], f"Error fetching issues: {e!s}"
def get_myself(self) -> dict[str, Any]:
"""Fetch the current user's profile (health check)."""
return self.make_api_request("myself")
def get_projects(self) -> list[dict[str, Any]]:
"""Fetch all projects the user has access to."""
result = self.make_api_request("project/search")
return result.get("values", [])
def get_issue_types(self) -> list[dict[str, Any]]:
"""Fetch all issue types."""
return self.make_api_request("issuetype")
def get_priorities(self) -> list[dict[str, Any]]:
"""Fetch all priority levels."""
return self.make_api_request("priority")
def get_issue(self, issue_id_or_key: str) -> dict[str, Any]:
"""Fetch a single issue by ID or key."""
return self.make_api_request(f"issue/{issue_id_or_key}")
def create_issue(
self,
project_key: str,
summary: str,
issue_type: str = "Task",
description: str | None = None,
priority: str | None = None,
assignee_id: str | None = None,
) -> dict[str, Any]:
"""Create a new Jira issue."""
fields: dict[str, Any] = {
"project": {"key": project_key},
"summary": summary,
"issuetype": {"name": issue_type},
}
if description:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [{"type": "text", "text": description}],
}
],
}
if priority:
fields["priority"] = {"name": priority}
if assignee_id:
fields["assignee"] = {"accountId": assignee_id}
return self.make_api_request(
"issue", method="POST", json_payload={"fields": fields}
)
def update_issue(
self, issue_id_or_key: str, fields: dict[str, Any]
) -> dict[str, Any]:
"""Update an existing Jira issue fields."""
return self.make_api_request(
f"issue/{issue_id_or_key}",
method="PUT",
json_payload={"fields": fields},
)
def delete_issue(self, issue_id_or_key: str) -> dict[str, Any]:
"""Delete a Jira issue."""
return self.make_api_request(f"issue/{issue_id_or_key}", method="DELETE")
def get_transitions(self, issue_id_or_key: str) -> list[dict[str, Any]]:
"""Get available transitions for an issue (for status changes)."""
result = self.make_api_request(f"issue/{issue_id_or_key}/transitions")
return result.get("transitions", [])
def transition_issue(
self, issue_id_or_key: str, transition_id: str
) -> dict[str, Any]:
"""Transition an issue to a new status."""
return self.make_api_request(
f"issue/{issue_id_or_key}/transitions",
method="POST",
json_payload={"transition": {"id": transition_id}},
)
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]: def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
""" """
Format an issue for easier consumption. Format an issue for easier consumption.

View file

@ -14,7 +14,6 @@ from sqlalchemy.future import select
from app.config import config from app.config import config
from app.connectors.jira_connector import JiraConnector from app.connectors.jira_connector import JiraConnector
from app.db import SearchSourceConnector from app.db import SearchSourceConnector
from app.routes.jira_add_connector_route import refresh_jira_token
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption from app.utils.oauth_security import TokenEncryption
@ -184,7 +183,9 @@ class JiraHistoryConnector:
f"Connector {self._connector_id} not found; cannot refresh token." f"Connector {self._connector_id} not found; cannot refresh token."
) )
# Refresh token # Lazy import to avoid circular dependency
from app.routes.jira_add_connector_route import refresh_jira_token
connector = await refresh_jira_token(self._session, connector) connector = await refresh_jira_token(self._session, connector)
# Reload credentials after refresh # Reload credentials after refresh

View file

@ -1,12 +1,12 @@
import asyncio import asyncio
import contextlib import contextlib
import logging import logging
import re
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any, TypeVar from typing import Any, TypeVar
from notion_client import AsyncClient from notion_client import AsyncClient
from notion_client.errors import APIResponseError from notion_client.errors import APIResponseError
from notion_markdown import to_notion
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
@ -834,106 +834,8 @@ class NotionHistoryConnector:
return None return None
def _markdown_to_blocks(self, markdown: str) -> list[dict[str, Any]]: def _markdown_to_blocks(self, markdown: str) -> list[dict[str, Any]]:
""" """Convert markdown content to Notion blocks using notion-markdown."""
Convert markdown content to Notion blocks. return to_notion(markdown)
This is a simple converter that handles basic markdown.
For more complex markdown, consider using a proper markdown parser.
Args:
markdown: Markdown content
Returns:
List of Notion block objects
"""
blocks = []
lines = markdown.split("\n")
for line in lines:
line = line.strip()
if not line:
continue
# Heading 1
if line.startswith("# "):
blocks.append(
{
"object": "block",
"type": "heading_1",
"heading_1": {
"rich_text": [
{"type": "text", "text": {"content": line[2:]}}
]
},
}
)
# Heading 2
elif line.startswith("## "):
blocks.append(
{
"object": "block",
"type": "heading_2",
"heading_2": {
"rich_text": [
{"type": "text", "text": {"content": line[3:]}}
]
},
}
)
# Heading 3
elif line.startswith("### "):
blocks.append(
{
"object": "block",
"type": "heading_3",
"heading_3": {
"rich_text": [
{"type": "text", "text": {"content": line[4:]}}
]
},
}
)
# Bullet list
elif line.startswith("- ") or line.startswith("* "):
blocks.append(
{
"object": "block",
"type": "bulleted_list_item",
"bulleted_list_item": {
"rich_text": [
{"type": "text", "text": {"content": line[2:]}}
]
},
}
)
# Numbered list
elif match := re.match(r"^(\d+)\.\s+(.*)$", line):
content = match.group(2) # Extract text after "number. "
blocks.append(
{
"object": "block",
"type": "numbered_list_item",
"numbered_list_item": {
"rich_text": [
{"type": "text", "text": {"content": content}}
]
},
}
)
# Regular paragraph
else:
blocks.append(
{
"object": "block",
"type": "paragraph",
"paragraph": {
"rich_text": [{"type": "text", "text": {"content": line}}]
},
}
)
return blocks
async def create_page( async def create_page(
self, title: str, content: str, parent_page_id: str | None = None self, title: str, content: str, parent_page_id: str | None = None

View file

@ -63,6 +63,16 @@ class DocumentType(StrEnum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
# Native Google document types → their legacy Composio equivalents.
# Old documents may still carry the Composio type until they are re-indexed;
# search, browse, and indexing must transparently handle both.
NATIVE_TO_LEGACY_DOCTYPE: dict[str, str] = {
"GOOGLE_DRIVE_FILE": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
"GOOGLE_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
"GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
}
class SearchSourceConnectorType(StrEnum): class SearchSourceConnectorType(StrEnum):
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
TAVILY_API = "TAVILY_API" TAVILY_API = "TAVILY_API"
@ -712,7 +722,7 @@ class ChatComment(BaseModel, TimestampMixin):
nullable=False, nullable=False,
index=True, index=True,
) )
# Denormalized thread_id for efficient Electric SQL subscriptions (one per thread) # Denormalized thread_id for efficient Zero subscriptions (one per thread)
thread_id = Column( thread_id = Column(
Integer, Integer,
ForeignKey("new_chat_threads.id", ondelete="CASCADE"), ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
@ -782,7 +792,7 @@ class ChatCommentMention(BaseModel, TimestampMixin):
class ChatSessionState(BaseModel): class ChatSessionState(BaseModel):
""" """
Tracks real-time session state for shared chat collaboration. Tracks real-time session state for shared chat collaboration.
One record per thread, synced via Electric SQL. One record per thread, synced via Zero.
""" """
__tablename__ = "chat_session_state" __tablename__ = "chat_session_state"

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import contextlib
import time import time
from datetime import datetime from datetime import datetime
@ -157,7 +158,7 @@ class ChucksHybridSearchRetriever:
query_text: str, query_text: str,
top_k: int, top_k: int,
search_space_id: int, search_space_id: int,
document_type: str | None = None, document_type: str | list[str] | None = None,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
query_embedding: list | None = None, query_embedding: list | None = None,
@ -217,18 +218,24 @@ class ChucksHybridSearchRetriever:
func.coalesce(Document.status["state"].astext, "ready") != "deleting", func.coalesce(Document.status["state"].astext, "ready") != "deleting",
] ]
# Add document type filter if provided # Add document type filter if provided (single string or list of strings)
if document_type is not None: if document_type is not None:
# Convert string to enum value if needed type_list = (
if isinstance(document_type, str): document_type if isinstance(document_type, list) else [document_type]
try: )
doc_type_enum = DocumentType[document_type] doc_type_enums = []
base_conditions.append(Document.document_type == doc_type_enum) for dt in type_list:
except KeyError: if isinstance(dt, str):
# If the document type doesn't exist in the enum, return empty results with contextlib.suppress(KeyError):
return [] doc_type_enums.append(DocumentType[dt])
else:
doc_type_enums.append(dt)
if not doc_type_enums:
return []
if len(doc_type_enums) == 1:
base_conditions.append(Document.document_type == doc_type_enums[0])
else: else:
base_conditions.append(Document.document_type == document_type) base_conditions.append(Document.document_type.in_(doc_type_enums))
# Add time-based filtering if provided # Add time-based filtering if provided
if start_date is not None: if start_date is not None:

View file

@ -1,3 +1,4 @@
import contextlib
import time import time
from datetime import datetime from datetime import datetime
@ -149,7 +150,7 @@ class DocumentHybridSearchRetriever:
query_text: str, query_text: str,
top_k: int, top_k: int,
search_space_id: int, search_space_id: int,
document_type: str | None = None, document_type: str | list[str] | None = None,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
query_embedding: list | None = None, query_embedding: list | None = None,
@ -197,18 +198,24 @@ class DocumentHybridSearchRetriever:
func.coalesce(Document.status["state"].astext, "ready") != "deleting", func.coalesce(Document.status["state"].astext, "ready") != "deleting",
] ]
# Add document type filter if provided # Add document type filter if provided (single string or list of strings)
if document_type is not None: if document_type is not None:
# Convert string to enum value if needed type_list = (
if isinstance(document_type, str): document_type if isinstance(document_type, list) else [document_type]
try: )
doc_type_enum = DocumentType[document_type] doc_type_enums = []
base_conditions.append(Document.document_type == doc_type_enum) for dt in type_list:
except KeyError: if isinstance(dt, str):
# If the document type doesn't exist in the enum, return empty results with contextlib.suppress(KeyError):
return [] doc_type_enums.append(DocumentType[dt])
else:
doc_type_enums.append(dt)
if not doc_type_enums:
return []
if len(doc_type_enums) == 1:
base_conditions.append(Document.document_type == doc_type_enums[0])
else: else:
base_conditions.append(Document.document_type == document_type) base_conditions.append(Document.document_type.in_(doc_type_enums))
# Add time-based filtering if provided # Add time-based filtering if provided
if start_date is not None: if start_date is not None:

View file

@ -80,7 +80,7 @@ router.include_router(model_list_router) # Dynamic LLM model catalogue from Ope
router.include_router(logs_router) router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(circleback_webhook_router) # Circleback meeting webhooks
router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(surfsense_docs_router) # Surfsense documentation for citations
router.include_router(notifications_router) # Notifications with Electric SQL sync router.include_router(notifications_router) # Notifications with Zero sync
router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(composio_router) # Composio OAuth and toolkit management
router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(public_chat_router) # Public chat sharing and cloning
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages

View file

@ -199,7 +199,7 @@ async def airtable_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=airtable_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -316,7 +316,7 @@ async def airtable_callback(
f"Duplicate Airtable connector detected for user {user_id} with email {user_email}" f"Duplicate Airtable connector detected for user {user_id} with email {user_email}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=airtable-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=airtable-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -348,7 +348,7 @@ async def airtable_callback(
# Redirect to the frontend with success params for indexing config # Redirect to the frontend with success params for indexing config
# Using query params to auto-open the popup with config view on new-chat page # Using query params to auto-open the popup with config view on new-chat page
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=airtable-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=airtable-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:

View file

@ -148,7 +148,7 @@ async def clickup_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=clickup_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=clickup_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -326,7 +326,7 @@ async def clickup_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=clickup-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=clickup-connector"
) )
except ValidationError as e: except ValidationError as e:

View file

@ -208,7 +208,7 @@ async def composio_callback(
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=composio_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=composio_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -263,6 +263,15 @@ async def composio_callback(
logger.info( logger.info(
f"Successfully got connected_account_id: {final_connected_account_id}" f"Successfully got connected_account_id: {final_connected_account_id}"
) )
# Wait for Composio to finish exchanging the auth code for tokens.
try:
service.wait_for_connection(final_connected_account_id, timeout=30.0)
except Exception:
logger.warning(
f"wait_for_connection timed out for {final_connected_account_id}, "
"proceeding anyway",
exc_info=True,
)
# Build entity_id for Composio API calls (same format as used in initiate) # Build entity_id for Composio API calls (same format as used in initiate)
entity_id = f"surfsense_{user_id}" entity_id = f"surfsense_{user_id}"
@ -370,7 +379,7 @@ async def composio_callback(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}"
) )
# This is a NEW account - create a new connector # This is a NEW account - create a new connector
@ -399,7 +408,7 @@ async def composio_callback(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={db_connector.id}"
) )
except IntegrityError as e: except IntegrityError as e:
@ -425,6 +434,211 @@ async def composio_callback(
) from e ) from e
COMPOSIO_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
}
@router.get("/auth/composio/connector/reauth")
async def reauth_composio_connector(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""
Initiate Composio re-authentication for an expired connected account.
Uses Composio's refresh API so the same connected_account_id stays valid
after the user completes the OAuth flow again.
Query params:
space_id: Search space ID the connector belongs to
connector_id: ID of the existing Composio connector to re-authenticate
return_url: Optional frontend path to redirect to after completion
"""
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503, detail="Composio integration is not enabled."
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type.in_(COMPOSIO_CONNECTOR_TYPES),
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Composio connector not found or access denied",
)
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
raise HTTPException(
status_code=400,
detail="Composio connected account ID not found. Please reconnect the connector.",
)
# Build callback URL with secure state
state_manager = get_state_manager()
state_encoded = state_manager.generate_secure_state(
space_id,
user.id,
toolkit_id=connector.config.get("toolkit_id", ""),
connector_id=connector_id,
return_url=return_url,
)
callback_base = config.COMPOSIO_REDIRECT_URI
if not callback_base:
backend_url = config.BACKEND_URL or "http://localhost:8000"
callback_base = (
f"{backend_url}/api/v1/auth/composio/connector/reauth/callback"
)
else:
# Replace the normal callback path with the reauth one
callback_base = callback_base.replace(
"/auth/composio/connector/callback",
"/auth/composio/connector/reauth/callback",
)
callback_url = f"{callback_base}?state={state_encoded}"
service = ComposioService()
refresh_result = service.refresh_connected_account(
connected_account_id=connected_account_id,
redirect_url=callback_url,
)
if refresh_result["redirect_url"] is None:
# Token refreshed server-side; clear auth_expired immediately
if connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": False}
flag_modified(connector, "config")
await session.commit()
logger.info(
f"Composio account {connected_account_id} refreshed server-side (no redirect needed)"
)
return {
"success": True,
"message": "Authentication refreshed successfully.",
}
logger.info(f"Initiating Composio re-auth for connector {connector_id}")
return {"auth_url": refresh_result["redirect_url"]}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Composio re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Composio re-auth: {e!s}"
) from e
@router.get("/auth/composio/connector/reauth/callback")
async def composio_reauth_callback(
request: Request,
state: str | None = None,
session: AsyncSession = Depends(get_async_session),
):
"""
Handle Composio re-authentication callback.
Clears the auth_expired flag and redirects the user back to the frontend.
The connected_account_id has not changed Composio refreshed it in place.
"""
try:
if not state:
raise HTTPException(status_code=400, detail="Missing state parameter")
state_manager = get_state_manager()
try:
data = state_manager.validate_state(state)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Invalid state parameter: {e!s}"
) from e
user_id = UUID(data["user_id"])
space_id = data["space_id"]
reauth_connector_id = data.get("connector_id")
return_url = data.get("return_url")
if not reauth_connector_id:
raise HTTPException(status_code=400, detail="Missing connector_id in state")
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth callback",
)
# Wait for Composio to finish processing new tokens before proceeding.
# Without this, get_access_token() may return stale credentials.
connected_account_id = connector.config.get("composio_connected_account_id")
if connected_account_id:
try:
service = ComposioService()
service.wait_for_connection(connected_account_id, timeout=30.0)
except Exception:
logger.warning(
f"wait_for_connection timed out for connector {reauth_connector_id}, "
"proceeding anyway — tokens may not be ready yet",
exc_info=True,
)
# Clear auth_expired flag
connector.config = {**connector.config, "auth_expired": False}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
logger.info(f"Composio re-auth completed for connector {reauth_connector_id}")
if return_url and return_url.startswith("/"):
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{return_url}")
frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
connector.config.get("toolkit_id", ""), "composio-connector"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={reauth_connector_id}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in Composio reauth callback: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to complete Composio re-auth: {e!s}"
) from e
@router.get("/connectors/{connector_id}/composio-drive/folders") @router.get("/connectors/{connector_id}/composio-drive/folders")
async def list_composio_drive_folders( async def list_composio_drive_folders(
connector_id: int, connector_id: int,
@ -433,31 +647,23 @@ async def list_composio_drive_folders(
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
""" """
List folders AND files in user's Google Drive via Composio with hierarchical support. List folders AND files in user's Google Drive via Composio.
This is called at index time from the manage connector page to display Uses the same GoogleDriveClient / list_folder_contents path as the native
the complete file system (folders and files). Only folders are selectable. connector, with Composio-sourced credentials. This means auth errors
propagate identically (Google returns 401 exception auth_expired flag).
Args:
connector_id: ID of the Composio Google Drive connector
parent_id: Optional parent folder ID to list contents (None for root)
Returns:
JSON with list of items: {
"items": [
{"id": str, "name": str, "mimeType": str, "isFolder": bool, ...},
...
]
}
""" """
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
from app.utils.google_credentials import build_composio_credentials
if not ComposioService.is_enabled(): if not ComposioService.is_enabled():
raise HTTPException( raise HTTPException(
status_code=503, status_code=503,
detail="Composio integration is not enabled.", detail="Composio integration is not enabled.",
) )
connector = None
try: try:
# Get connector and verify ownership
result = await session.execute( result = await session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id, SearchSourceConnector.id == connector_id,
@ -474,7 +680,6 @@ async def list_composio_drive_folders(
detail="Composio Google Drive connector not found or access denied", detail="Composio Google Drive connector not found or access denied",
) )
# Get Composio connected account ID from config
composio_connected_account_id = connector.config.get( composio_connected_account_id = connector.config.get(
"composio_connected_account_id" "composio_connected_account_id"
) )
@ -484,63 +689,43 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.", detail="Composio connected account not found. Please reconnect the connector.",
) )
# Initialize Composio service and fetch files credentials = build_composio_credentials(composio_connected_account_id)
service = ComposioService() drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
entity_id = f"surfsense_{user.id}"
# Fetch files/folders from Composio Google Drive items, error = await list_folder_contents(drive_client, parent_id=parent_id)
files, _next_token, error = await service.get_drive_files(
connected_account_id=composio_connected_account_id,
entity_id=entity_id,
folder_id=parent_id,
page_size=100,
)
if error: if error:
logger.error(f"Failed to list Composio Drive files: {error}") error_lower = error.lower()
if (
"401" in error
or "invalid_grant" in error_lower
or "token has been expired or revoked" in error_lower
or "invalid credentials" in error_lower
or "authentication failed" in error_lower
):
try:
if connector and not connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
logger.info(
f"Marked Composio connector {connector_id} as auth_expired"
)
except Exception:
logger.warning(
f"Failed to persist auth_expired for connector {connector_id}",
exc_info=True,
)
raise HTTPException(
status_code=400,
detail="Google Drive authentication expired. Please re-authenticate.",
)
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to list folder contents: {error}" status_code=500, detail=f"Failed to list folder contents: {error}"
) )
# Transform files to match the expected format with isFolder field folder_count = sum(1 for item in items if item.get("isFolder", False))
items = [] file_count = len(items) - folder_count
for file_info in files:
file_id = file_info.get("id", "") or file_info.get("fileId", "")
file_name = (
file_info.get("name", "") or file_info.get("fileName", "") or "Untitled"
)
mime_type = file_info.get("mimeType", "") or file_info.get("mime_type", "")
if not file_id:
continue
is_folder = mime_type == "application/vnd.google-apps.folder"
items.append(
{
"id": file_id,
"name": file_name,
"mimeType": mime_type,
"isFolder": is_folder,
"parents": file_info.get("parents", []),
"size": file_info.get("size"),
"iconLink": file_info.get("iconLink"),
}
)
# Sort: folders first, then files, both alphabetically
folders = sorted(
[item for item in items if item["isFolder"]],
key=lambda x: x["name"].lower(),
)
files_list = sorted(
[item for item in items if not item["isFolder"]],
key=lambda x: x["name"].lower(),
)
items = folders + files_list
folder_count = len(folders)
file_count = len(files_list)
logger.info( logger.info(
f"Listed {len(items)} total items ({folder_count} folders, {file_count} files) for Composio connector {connector_id}" f"Listed {len(items)} total items ({folder_count} folders, {file_count} files) for Composio connector {connector_id}"
@ -553,6 +738,31 @@ async def list_composio_drive_folders(
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error listing Composio Drive contents: {e!s}", exc_info=True) logger.error(f"Error listing Composio Drive contents: {e!s}", exc_info=True)
error_lower = str(e).lower()
if (
"invalid_grant" in error_lower
or "token has been expired or revoked" in error_lower
or "invalid credentials" in error_lower
or "authentication failed" in error_lower
or "401" in str(e)
):
try:
if connector and not connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
logger.info(
f"Marked Composio connector {connector_id} as auth_expired"
)
except Exception:
logger.warning(
f"Failed to persist auth_expired for connector {connector_id}",
exc_info=True,
)
raise HTTPException(
status_code=400,
detail="Google Drive authentication expired. Please re-authenticate.",
) from e
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to list Drive contents: {e!s}" status_code=500, detail=f"Failed to list Drive contents: {e!s}"
) from e ) from e

View file

@ -46,6 +46,8 @@ SCOPES = [
"read:space:confluence", "read:space:confluence",
"read:page:confluence", "read:page:confluence",
"read:comment:confluence", "read:comment:confluence",
"write:page:confluence", # Required for creating/updating pages
"delete:page:confluence", # Required for deleting pages
"offline_access", # Required for refresh tokens "offline_access", # Required for refresh tokens
] ]
@ -170,7 +172,7 @@ async def confluence_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=confluence_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=confluence_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -196,6 +198,8 @@ async def confluence_callback(
user_id = UUID(data["user_id"]) user_id = UUID(data["user_id"])
space_id = data["space_id"] space_id = data["space_id"]
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
# Validate redirect URI (security: ensure it matches configured value) # Validate redirect URI (security: ensure it matches configured value)
if not config.CONFLUENCE_REDIRECT_URI: if not config.CONFLUENCE_REDIRECT_URI:
@ -292,6 +296,46 @@ async def confluence_callback(
"_token_encrypted": True, "_token_encrypted": True,
} }
# Handle re-authentication: update existing connector instead of creating new one
if reauth_connector_id:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
result = await session.execute(
sa_select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = {
**connector_config,
"auth_expired": False,
}
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Confluence connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=confluence-connector"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=confluence-connector"
)
# Extract unique identifier from connector credentials # Extract unique identifier from connector credentials
connector_identifier = extract_identifier_from_credentials( connector_identifier = extract_identifier_from_credentials(
SearchSourceConnectorType.CONFLUENCE_CONNECTOR, connector_config SearchSourceConnectorType.CONFLUENCE_CONNECTOR, connector_config
@ -310,7 +354,7 @@ async def confluence_callback(
f"Duplicate Confluence connector detected for user {user_id} with instance {connector_identifier}" f"Duplicate Confluence connector detected for user {user_id} with instance {connector_identifier}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=confluence-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=confluence-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -341,7 +385,7 @@ async def confluence_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=confluence-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=confluence-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:
@ -372,6 +416,73 @@ async def confluence_callback(
) from e ) from e
@router.get("/auth/confluence/connector/reauth")
async def reauth_confluence(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
try:
from sqlalchemy.future import select
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Confluence connector not found or access denied",
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
from urllib.parse import urlencode
auth_params = {
"audience": "api.atlassian.com",
"client_id": config.ATLASSIAN_CLIENT_ID,
"scope": " ".join(SCOPES),
"redirect_uri": config.CONFLUENCE_REDIRECT_URI,
"state": state_encoded,
"response_type": "code",
"prompt": "consent",
}
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
logger.info(
f"Initiating Confluence re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Confluence re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Confluence re-auth: {e!s}"
) from e
async def refresh_confluence_token( async def refresh_confluence_token(
session: AsyncSession, connector: SearchSourceConnector session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector: ) -> SearchSourceConnector:

View file

@ -172,7 +172,7 @@ async def discord_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=discord_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=discord_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -311,7 +311,7 @@ async def discord_callback(
f"Duplicate Discord connector detected for user {user_id} with server {connector_identifier}" f"Duplicate Discord connector detected for user {user_id} with server {connector_identifier}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=discord-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=discord-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -342,7 +342,7 @@ async def discord_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=discord-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=discord-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:

View file

@ -128,7 +128,7 @@ async def create_documents_file_upload(
Upload files as documents with real-time status tracking. Upload files as documents with real-time status tracking.
Implements 2-phase document status updates for real-time UI feedback: Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately via ElectricSQL) - Phase 1: Create all documents with 'pending' status (visible in UI immediately via Zero)
- Phase 2: Celery processes each file: pending processing ready/failed - Phase 2: Celery processes each file: pending processing ready/failed
Requires DOCUMENTS_CREATE permission. Requires DOCUMENTS_CREATE permission.

View file

@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from google_auth_oauthlib.flow import Flow from google_auth_oauthlib.flow import Flow
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email from app.connectors.google_gmail_connector import fetch_google_user_email
@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"] SCOPES = ["https://www.googleapis.com/auth/calendar.events"]
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI
# Initialize security utilities # Initialize security utilities
@ -111,6 +113,66 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
) from e ) from e
@router.get("/auth/google/calendar/connector/reauth")
async def reauth_calendar(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Google Calendar re-authentication for an existing connector."""
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Google Calendar connector not found or access denied",
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
flow = get_google_flow()
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
auth_url, _ = flow.authorization_url(
access_type="offline",
prompt="consent",
include_granted_scopes="true",
state=state_encoded,
)
logger.info(
f"Initiating Google Calendar re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Calendar re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Calendar re-auth: {e!s}"
) from e
@router.get("/auth/google/calendar/connector/callback") @router.get("/auth/google/calendar/connector/callback")
async def calendar_callback( async def calendar_callback(
request: Request, request: Request,
@ -137,7 +199,7 @@ async def calendar_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_calendar_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -197,6 +259,42 @@ async def calendar_callback(
# Mark that credentials are encrypted for backward compatibility # Mark that credentials are encrypted for backward compatibility
creds_dict["_token_encrypted"] = True creds_dict["_token_encrypted"] = True
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = {**creds_dict}
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Calendar connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
)
# Check for duplicate connector (same account already connected) # Check for duplicate connector (same account already connected)
is_duplicate = await check_duplicate_connector( is_duplicate = await check_duplicate_connector(
session, session,
@ -210,7 +308,7 @@ async def calendar_callback(
f"Duplicate Google Calendar connector detected for user {user_id} with email {user_email}" f"Duplicate Google Calendar connector detected for user {user_id} with email {user_email}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-calendar-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-calendar-connector"
) )
try: try:
@ -236,7 +334,7 @@ async def calendar_callback(
# Redirect to the frontend with success params for indexing config # Redirect to the frontend with success params for indexing config
# Using query params to auto-open the popup with config view on new-chat page # Using query params to auto-open the popup with config view on new-chat page
return RedirectResponse( return RedirectResponse(
f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-calendar-connector&connectorId={db_connector.id}" f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
) )
except ValidationError as e: except ValidationError as e:
await session.rollback() await session.rollback()

View file

@ -257,7 +257,7 @@ async def drive_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_drive_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -345,6 +345,7 @@ async def drive_callback(
db_connector.config = { db_connector.config = {
**creds_dict, **creds_dict,
"start_page_token": existing_start_page_token, "start_page_token": existing_start_page_token,
"auth_expired": False,
} }
from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.orm.attributes import flag_modified
@ -360,7 +361,7 @@ async def drive_callback(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
) )
is_duplicate = await check_duplicate_connector( is_duplicate = await check_duplicate_connector(
@ -375,7 +376,7 @@ async def drive_callback(
f"Duplicate Google Drive connector detected for user {user_id} with email {user_email}" f"Duplicate Google Drive connector detected for user {user_id} with email {user_email}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-drive-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-drive-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -425,7 +426,7 @@ async def drive_callback(
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
) )
except HTTPException: except HTTPException:
@ -502,11 +503,35 @@ async def list_google_drive_folders(
items, error = await list_folder_contents(drive_client, parent_id=parent_id) items, error = await list_folder_contents(drive_client, parent_id=parent_id)
if error: if error:
error_lower = error.lower()
if (
"401" in error
or "invalid_grant" in error_lower
or "token has been expired or revoked" in error_lower
or "invalid credentials" in error_lower
or "authentication failed" in error_lower
):
from sqlalchemy.orm.attributes import flag_modified
try:
if connector and not connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
logger.info(f"Marked connector {connector_id} as auth_expired")
except Exception:
logger.warning(
f"Failed to persist auth_expired for connector {connector_id}",
exc_info=True,
)
raise HTTPException(
status_code=400,
detail="Google Drive authentication expired. Please re-authenticate.",
)
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to list folder contents: {error}" status_code=500, detail=f"Failed to list folder contents: {error}"
) )
# Count folders and files for better logging
folder_count = sum(1 for item in items if item.get("isFolder", False)) folder_count = sum(1 for item in items if item.get("isFolder", False))
file_count = len(items) - folder_count file_count = len(items) - folder_count
@ -515,7 +540,6 @@ async def list_google_drive_folders(
+ (f" in folder {parent_id}" if parent_id else " in ROOT") + (f" in folder {parent_id}" if parent_id else " in ROOT")
) )
# Log first few items for debugging
if items: if items:
logger.info(f"First 3 items: {[item.get('name') for item in items[:3]]}") logger.info(f"First 3 items: {[item.get('name') for item in items[:3]]}")
@ -525,6 +549,31 @@ async def list_google_drive_folders(
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error listing Drive contents: {e!s}", exc_info=True) logger.error(f"Error listing Drive contents: {e!s}", exc_info=True)
error_lower = str(e).lower()
if (
"401" in str(e)
or "invalid_grant" in error_lower
or "token has been expired or revoked" in error_lower
or "invalid credentials" in error_lower
or "authentication failed" in error_lower
):
from sqlalchemy.orm.attributes import flag_modified
try:
if connector and not connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
logger.info(f"Marked connector {connector_id} as auth_expired")
except Exception:
logger.warning(
f"Failed to persist auth_expired for connector {connector_id}",
exc_info=True,
)
raise HTTPException(
status_code=400,
detail="Google Drive authentication expired. Please re-authenticate.",
) from e
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to list Drive contents: {e!s}" status_code=500, detail=f"Failed to list Drive contents: {e!s}"
) from e ) from e

View file

@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from google_auth_oauthlib.flow import Flow from google_auth_oauthlib.flow import Flow
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email from app.connectors.google_gmail_connector import fetch_google_user_email
@ -71,7 +73,7 @@ def get_google_flow():
} }
}, },
scopes=[ scopes=[
"https://www.googleapis.com/auth/gmail.readonly", "https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.profile",
"openid", "openid",
@ -129,6 +131,66 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
) from e ) from e
@router.get("/auth/google/gmail/connector/reauth")
async def reauth_gmail(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Gmail re-authentication for an existing connector."""
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Gmail connector not found or access denied",
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
flow = get_google_flow()
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
auth_url, _ = flow.authorization_url(
access_type="offline",
prompt="consent",
include_granted_scopes="true",
state=state_encoded,
)
logger.info(
f"Initiating Gmail re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Gmail re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Gmail re-auth: {e!s}"
) from e
@router.get("/auth/google/gmail/connector/callback") @router.get("/auth/google/gmail/connector/callback")
async def gmail_callback( async def gmail_callback(
request: Request, request: Request,
@ -168,7 +230,7 @@ async def gmail_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_gmail_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -228,6 +290,42 @@ async def gmail_callback(
# Mark that credentials are encrypted for backward compatibility # Mark that credentials are encrypted for backward compatibility
creds_dict["_token_encrypted"] = True creds_dict["_token_encrypted"] = True
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = {**creds_dict}
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Gmail connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
)
# Check for duplicate connector (same account already connected) # Check for duplicate connector (same account already connected)
is_duplicate = await check_duplicate_connector( is_duplicate = await check_duplicate_connector(
session, session,
@ -241,7 +339,7 @@ async def gmail_callback(
f"Duplicate Gmail connector detected for user {user_id} with email {user_email}" f"Duplicate Gmail connector detected for user {user_id} with email {user_email}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-gmail-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-gmail-connector"
) )
try: try:
@ -272,7 +370,7 @@ async def gmail_callback(
# Redirect to the frontend with success params for indexing config # Redirect to the frontend with success params for indexing config
# Using query params to auto-open the popup with config view on new-chat page # Using query params to auto-open the popup with config view on new-chat page
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-gmail-connector&connectorId={db_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
) )
except IntegrityError as e: except IntegrityError as e:

View file

@ -45,6 +45,7 @@ ACCESSIBLE_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-res
SCOPES = [ SCOPES = [
"read:jira-work", "read:jira-work",
"read:jira-user", "read:jira-user",
"write:jira-work", # Required for creating/updating/deleting issues
"offline_access", # Required for refresh tokens "offline_access", # Required for refresh tokens
] ]
@ -167,7 +168,7 @@ async def jira_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=jira_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=jira_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -193,6 +194,8 @@ async def jira_callback(
user_id = UUID(data["user_id"]) user_id = UUID(data["user_id"])
space_id = data["space_id"] space_id = data["space_id"]
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
# Validate redirect URI (security: ensure it matches configured value) # Validate redirect URI (security: ensure it matches configured value)
if not config.JIRA_REDIRECT_URI: if not config.JIRA_REDIRECT_URI:
@ -310,6 +313,46 @@ async def jira_callback(
"_token_encrypted": True, "_token_encrypted": True,
} }
# Handle re-authentication: update existing connector instead of creating new one
if reauth_connector_id:
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm.attributes import flag_modified
result = await session.execute(
sa_select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = {
**connector_config,
"auth_expired": False,
}
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Jira connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=jira-connector"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=jira-connector"
)
# Extract unique identifier from connector credentials # Extract unique identifier from connector credentials
connector_identifier = extract_identifier_from_credentials( connector_identifier = extract_identifier_from_credentials(
SearchSourceConnectorType.JIRA_CONNECTOR, connector_config SearchSourceConnectorType.JIRA_CONNECTOR, connector_config
@ -328,7 +371,7 @@ async def jira_callback(
f"Duplicate Jira connector detected for user {user_id} with instance {connector_identifier}" f"Duplicate Jira connector detected for user {user_id} with instance {connector_identifier}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=jira-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=jira-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -359,7 +402,7 @@ async def jira_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=jira-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=jira-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:
@ -390,6 +433,73 @@ async def jira_callback(
) from e ) from e
@router.get("/auth/jira/connector/reauth")
async def reauth_jira(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
try:
from sqlalchemy.future import select
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Jira connector not found or access denied",
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
from urllib.parse import urlencode
auth_params = {
"audience": "api.atlassian.com",
"client_id": config.ATLASSIAN_CLIENT_ID,
"scope": " ".join(SCOPES),
"redirect_uri": config.JIRA_REDIRECT_URI,
"state": state_encoded,
"response_type": "code",
"prompt": "consent",
}
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
logger.info(
f"Initiating Jira re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Jira re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Jira re-auth: {e!s}"
) from e
async def refresh_jira_token( async def refresh_jira_token(
session: AsyncSession, connector: SearchSourceConnector session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector: ) -> SearchSourceConnector:

View file

@ -12,8 +12,10 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config from app.config import config
from app.connectors.linear_connector import fetch_linear_organization_name from app.connectors.linear_connector import fetch_linear_organization_name
@ -127,6 +129,70 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
) from e ) from e
@router.get("/auth/linear/connector/reauth")
async def reauth_linear(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Linear re-authentication for an existing connector."""
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Linear connector not found or access denied",
)
if not config.LINEAR_CLIENT_ID:
raise HTTPException(status_code=500, detail="Linear OAuth not configured.")
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
from urllib.parse import urlencode
auth_params = {
"client_id": config.LINEAR_CLIENT_ID,
"response_type": "code",
"redirect_uri": config.LINEAR_REDIRECT_URI,
"scope": " ".join(SCOPES),
"state": state_encoded,
}
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
logger.info(
f"Initiating Linear re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Linear re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Linear re-auth: {e!s}"
) from e
@router.get("/auth/linear/connector/callback") @router.get("/auth/linear/connector/callback")
async def linear_callback( async def linear_callback(
request: Request, request: Request,
@ -166,7 +232,7 @@ async def linear_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=linear_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -267,6 +333,43 @@ async def linear_callback(
"_token_encrypted": True, "_token_encrypted": True,
} }
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
connector_config["organization_name"] = org_name
db_connector.config = connector_config
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Linear connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={db_connector.id}"
)
# Check for duplicate connector (same organization already connected) # Check for duplicate connector (same organization already connected)
is_duplicate = await check_duplicate_connector( is_duplicate = await check_duplicate_connector(
session, session,
@ -280,7 +383,7 @@ async def linear_callback(
f"Duplicate Linear connector detected for user {user_id} with org {org_name}" f"Duplicate Linear connector detected for user {user_id} with org {org_name}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=linear-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=linear-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -292,6 +395,7 @@ async def linear_callback(
org_name, org_name,
) )
# Create new connector # Create new connector
connector_config["organization_name"] = org_name
new_connector = SearchSourceConnector( new_connector = SearchSourceConnector(
name=connector_name, name=connector_name,
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR, connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
@ -311,7 +415,7 @@ async def linear_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:
@ -342,6 +446,22 @@ async def linear_callback(
) from e ) from e
async def _mark_connector_auth_expired(
session: AsyncSession, connector: SearchSourceConnector
) -> None:
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
except Exception:
logger.warning(
f"Failed to persist auth_expired flag for connector {connector.id}",
exc_info=True,
)
async def refresh_linear_token( async def refresh_linear_token(
session: AsyncSession, connector: SearchSourceConnector session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector: ) -> SearchSourceConnector:
@ -375,6 +495,7 @@ async def refresh_linear_token(
) from e ) from e
if not refresh_token: if not refresh_token:
await _mark_connector_auth_expired(session, connector)
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="No refresh token available. Please re-authenticate.", detail="No refresh token available. Please re-authenticate.",
@ -417,6 +538,7 @@ async def refresh_linear_token(
or "expired" in error_lower or "expired" in error_lower
or "revoked" in error_lower or "revoked" in error_lower
): ):
await _mark_connector_auth_expired(session, connector)
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Linear authentication failed. Please re-authenticate.", detail="Linear authentication failed. Please re-authenticate.",
@ -453,10 +575,16 @@ async def refresh_linear_token(
credentials.expires_at = expires_at credentials.expires_at = expires_at
credentials.scope = token_json.get("scope") credentials.scope = token_json.get("scope")
# Update connector config with encrypted tokens # Update connector config with encrypted tokens, preserving non-credential fields
credentials_dict = credentials.to_dict() credentials_dict = credentials.to_dict()
credentials_dict["_token_encrypted"] = True credentials_dict["_token_encrypted"] = True
if connector.config.get("organization_name"):
credentials_dict["organization_name"] = connector.config[
"organization_name"
]
credentials_dict.pop("auth_expired", None)
connector.config = credentials_dict connector.config = credentials_dict
flag_modified(connector, "config")
await session.commit() await session.commit()
await session.refresh(connector) await session.refresh(connector)

View file

@ -1,7 +1,7 @@
""" """
Notifications API routes. Notifications API routes.
These endpoints allow marking notifications as read and fetching older notifications. These endpoints allow marking notifications as read and fetching older notifications.
Electric SQL automatically syncs the changes to all connected clients for recent items. Zero automatically syncs the changes to all connected clients for recent items.
For older items (beyond the sync window), use the list endpoint. For older items (beyond the sync window), use the list endpoint.
""" """
@ -267,7 +267,7 @@ async def get_unread_count(
This allows the frontend to calculate: This allows the frontend to calculate:
- older_unread = total_unread - recent_unread (static until reconciliation) - older_unread = total_unread - recent_unread (static until reconciliation)
- Display count = older_unread + live_recent_count (from Electric SQL) - Display count = older_unread + live_recent_count (from Zero)
""" """
# Calculate cutoff date for sync window # Calculate cutoff date for sync window
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
@ -344,7 +344,7 @@ async def list_notifications(
List notifications for the current user with pagination. List notifications for the current user with pagination.
This endpoint is used as a fallback for older notifications that are This endpoint is used as a fallback for older notifications that are
outside the Electric SQL sync window (2 weeks). outside the Zero sync window (2 weeks).
Use `before_date` to paginate through older notifications efficiently. Use `before_date` to paginate through older notifications efficiently.
""" """
@ -487,7 +487,7 @@ async def mark_notification_as_read(
""" """
Mark a single notification as read. Mark a single notification as read.
Electric SQL will automatically sync this change to all connected clients. Zero will automatically sync this change to all connected clients.
""" """
# Verify the notification belongs to the user # Verify the notification belongs to the user
result = await session.execute( result = await session.execute(
@ -528,7 +528,7 @@ async def mark_all_notifications_as_read(
""" """
Mark all notifications as read for the current user. Mark all notifications as read for the current user.
Electric SQL will automatically sync these changes to all connected clients. Zero will automatically sync these changes to all connected clients.
""" """
# Update all unread notifications for the user # Update all unread notifications for the user
result = await session.execute( result = await session.execute(

View file

@ -12,8 +12,10 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config from app.config import config
from app.db import ( from app.db import (
@ -124,6 +126,70 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
) from e ) from e
@router.get("/auth/notion/connector/reauth")
async def reauth_notion(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Notion re-authentication for an existing connector."""
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Notion connector not found or access denied",
)
if not config.NOTION_CLIENT_ID:
raise HTTPException(status_code=500, detail="Notion OAuth not configured.")
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
from urllib.parse import urlencode
auth_params = {
"client_id": config.NOTION_CLIENT_ID,
"response_type": "code",
"owner": "user",
"redirect_uri": config.NOTION_REDIRECT_URI,
"state": state_encoded,
}
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
logger.info(
f"Initiating Notion re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Notion re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Notion re-auth: {e!s}"
) from e
@router.get("/auth/notion/connector/callback") @router.get("/auth/notion/connector/callback")
async def notion_callback( async def notion_callback(
request: Request, request: Request,
@ -163,7 +229,7 @@ async def notion_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=notion_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -266,6 +332,42 @@ async def notion_callback(
"_token_encrypted": True, "_token_encrypted": True,
} }
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.NOTION_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
db_connector.config = connector_config
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Notion connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={db_connector.id}"
)
# Extract unique identifier from connector credentials # Extract unique identifier from connector credentials
connector_identifier = extract_identifier_from_credentials( connector_identifier = extract_identifier_from_credentials(
SearchSourceConnectorType.NOTION_CONNECTOR, connector_config SearchSourceConnectorType.NOTION_CONNECTOR, connector_config
@ -284,7 +386,7 @@ async def notion_callback(
f"Duplicate Notion connector detected for user {user_id} with workspace {connector_identifier}" f"Duplicate Notion connector detected for user {user_id} with workspace {connector_identifier}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=notion-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=notion-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -315,7 +417,7 @@ async def notion_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:
@ -346,6 +448,22 @@ async def notion_callback(
) from e ) from e
async def _mark_connector_auth_expired(
session: AsyncSession, connector: SearchSourceConnector
) -> None:
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
except Exception:
logger.warning(
f"Failed to persist auth_expired flag for connector {connector.id}",
exc_info=True,
)
async def refresh_notion_token( async def refresh_notion_token(
session: AsyncSession, connector: SearchSourceConnector session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector: ) -> SearchSourceConnector:
@ -379,6 +497,7 @@ async def refresh_notion_token(
) from e ) from e
if not refresh_token: if not refresh_token:
await _mark_connector_auth_expired(session, connector)
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="No refresh token available. Please re-authenticate.", detail="No refresh token available. Please re-authenticate.",
@ -421,6 +540,7 @@ async def refresh_notion_token(
or "expired" in error_lower or "expired" in error_lower
or "revoked" in error_lower or "revoked" in error_lower
): ):
await _mark_connector_auth_expired(session, connector)
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Notion authentication failed. Please re-authenticate.", detail="Notion authentication failed. Please re-authenticate.",
@ -469,7 +589,9 @@ async def refresh_notion_token(
# Update connector config with encrypted tokens # Update connector config with encrypted tokens
credentials_dict = credentials.to_dict() credentials_dict = credentials.to_dict()
credentials_dict["_token_encrypted"] = True credentials_dict["_token_encrypted"] = True
credentials_dict.pop("auth_expired", None)
connector.config = credentials_dict connector.config = credentials_dict
flag_modified(connector, "config")
await session.commit() await session.commit()
await session.refresh(connector) await session.refresh(connector)

View file

@ -72,6 +72,7 @@ from app.tasks.connector_indexers import (
index_slack_messages, index_slack_messages,
) )
from app.users import current_active_user from app.users import current_active_user
from app.utils.connector_naming import ensure_unique_connector_name
from app.utils.indexing_locks import ( from app.utils.indexing_locks import (
acquire_connector_indexing_lock, acquire_connector_indexing_lock,
release_connector_indexing_lock, release_connector_indexing_lock,
@ -189,6 +190,12 @@ async def create_search_source_connector(
# Prepare connector data # Prepare connector data
connector_data = connector.model_dump() connector_data = connector.model_dump()
# MCP connectors support multiple instances — ensure unique name
if connector.connector_type == SearchSourceConnectorType.MCP_CONNECTOR:
connector_data["name"] = await ensure_unique_connector_name(
session, connector_data["name"], search_space_id, user.id
)
# Automatically set next_scheduled_at if periodic indexing is enabled # Automatically set next_scheduled_at if periodic indexing is enabled
if ( if (
connector.periodic_indexing_enabled connector.periodic_indexing_enabled
@ -949,23 +956,46 @@ async def index_connector_content(
index_google_drive_files_task, index_google_drive_files_task,
) )
if not drive_items or not drive_items.has_items(): if drive_items and drive_items.has_items():
raise HTTPException( logger.info(
status_code=400, f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
detail="Google Drive indexing requires drive_items body parameter with folders or files", f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
)
items_dict = drive_items.model_dump()
else:
# Quick Index / periodic sync: fall back to stored config
config = connector.config or {}
selected_folders = config.get("selected_folders", [])
selected_files = config.get("selected_files", [])
if not selected_folders and not selected_files:
raise HTTPException(
status_code=400,
detail="Google Drive indexing requires folders or files to be configured. "
"Please select folders/files to index.",
)
indexing_options = config.get(
"indexing_options",
{
"max_files_per_folder": 100,
"incremental_sync": True,
"include_subfolders": True,
},
)
items_dict = {
"folders": selected_folders,
"files": selected_files,
"indexing_options": indexing_options,
}
logger.info(
f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id} "
f"using existing config"
) )
logger.info(
f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
)
# Pass structured data to Celery task
index_google_drive_files_task.delay( index_google_drive_files_task.delay(
connector_id, connector_id,
search_space_id, search_space_id,
str(user.id), str(user.id),
drive_items.model_dump(), # Convert to dict for JSON serialization items_dict,
) )
response_message = "Google Drive indexing started in the background." response_message = "Google Drive indexing started in the background."
@ -1061,7 +1091,7 @@ async def index_connector_content(
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
): ):
from app.tasks.celery_tasks.connector_tasks import ( from app.tasks.celery_tasks.connector_tasks import (
index_composio_connector_task, index_google_drive_files_task,
) )
# For Composio Google Drive, if drive_items is provided, update connector config # For Composio Google Drive, if drive_items is provided, update connector config
@ -1095,34 +1125,72 @@ async def index_connector_content(
else: else:
logger.info( logger.info(
f"Triggering Composio Google Drive indexing for connector {connector_id} into search space {search_space_id} " f"Triggering Composio Google Drive indexing for connector {connector_id} into search space {search_space_id} "
f"using existing config (from {indexing_from} to {indexing_to})" f"using existing config"
) )
index_composio_connector_task.delay( # Extract config and build items_dict for index_google_drive_files_task
connector_id, search_space_id, str(user.id), indexing_from, indexing_to config = connector.config or {}
selected_folders = config.get("selected_folders", [])
selected_files = config.get("selected_files", [])
if not selected_folders and not selected_files:
raise HTTPException(
status_code=400,
detail="Composio Google Drive indexing requires folders or files to be configured. "
"Please select folders/files to index.",
)
indexing_options = config.get(
"indexing_options",
{
"max_files_per_folder": 100,
"incremental_sync": True,
"include_subfolders": True,
},
)
items_dict = {
"folders": selected_folders,
"files": selected_files,
"indexing_options": indexing_options,
}
index_google_drive_files_task.delay(
connector_id, search_space_id, str(user.id), items_dict
) )
response_message = ( response_message = (
"Composio Google Drive indexing started in the background." "Composio Google Drive indexing started in the background."
) )
elif connector.connector_type in [ elif (
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, connector.connector_type
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
]: ):
from app.tasks.celery_tasks.connector_tasks import ( from app.tasks.celery_tasks.connector_tasks import (
index_composio_connector_task, index_google_gmail_messages_task,
) )
# For Composio Gmail and Calendar, use the same date calculation logic as normal connectors
# This ensures consistent behavior and uses last_indexed_at to reduce API calls
# (includes special case: if indexed today, go back 1 day to avoid missing data)
logger.info( logger.info(
f"Triggering Composio connector indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" f"Triggering Composio Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
) )
index_composio_connector_task.delay( index_google_gmail_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to connector_id, search_space_id, str(user.id), indexing_from, indexing_to
) )
response_message = "Composio connector indexing started in the background." response_message = "Composio Gmail indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_calendar_events_task,
)
logger.info(
f"Triggering Composio Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_calendar_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = (
"Composio Google Calendar indexing started in the background."
)
else: else:
raise HTTPException( raise HTTPException(
@ -1229,6 +1297,48 @@ async def run_slack_indexing(
) )
_AUTH_ERROR_PATTERNS = (
"failed to refresh linear oauth",
"failed to refresh your notion connection",
"failed to refresh notion token",
"authentication failed",
"auth_expired",
"token has been expired or revoked",
"invalid_grant",
)
def _is_auth_error(error_message: str) -> bool:
"""Check if an error message indicates an OAuth token expiry failure."""
if not error_message:
return False
lower = error_message.lower()
return any(pattern in lower for pattern in _AUTH_ERROR_PATTERNS)
async def _persist_auth_expired(session: AsyncSession, connector_id: int) -> None:
"""Flag a connector as auth_expired so the frontend shows a re-auth prompt."""
from sqlalchemy.orm.attributes import flag_modified
try:
result = await session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if connector and not connector.config.get("auth_expired"):
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await session.commit()
logger.info(f"Marked connector {connector_id} as auth_expired")
except Exception:
logger.warning(
f"Failed to persist auth_expired for connector {connector_id}",
exc_info=True,
)
async def _run_indexing_with_notifications( async def _run_indexing_with_notifications(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -1433,7 +1543,7 @@ async def _run_indexing_with_notifications(
) )
await ( await (
session.commit() session.commit()
) # Commit to ensure Electric SQL syncs the notification update ) # Commit to ensure Zero syncs the notification update
elif documents_processed > 0: elif documents_processed > 0:
# Update notification to storing stage # Update notification to storing stage
if notification: if notification:
@ -1460,7 +1570,7 @@ async def _run_indexing_with_notifications(
) )
await ( await (
session.commit() session.commit()
) # Commit to ensure Electric SQL syncs the notification update ) # Commit to ensure Zero syncs the notification update
else: else:
# No new documents processed - check if this is an error or just no changes # No new documents processed - check if this is an error or just no changes
if error_or_warning: if error_or_warning:
@ -1486,7 +1596,7 @@ async def _run_indexing_with_notifications(
if is_duplicate_warning or is_empty_result or is_info_warning: if is_duplicate_warning or is_empty_result or is_info_warning:
# These are success cases - sync worked, just found nothing new # These are success cases - sync worked, just found nothing new
logger.info(f"Indexing completed successfully: {error_or_warning}") logger.info(f"Indexing completed successfully: {error_or_warning}")
# Still update timestamp so ElectricSQL syncs and clears "Syncing" UI # Still update timestamp so Zero syncs and clears "Syncing" UI
if update_timestamp_func: if update_timestamp_func:
await update_timestamp_func(session, connector_id) await update_timestamp_func(session, connector_id)
await session.commit() # Commit timestamp update await session.commit() # Commit timestamp update
@ -1509,10 +1619,12 @@ async def _run_indexing_with_notifications(
) )
await ( await (
session.commit() session.commit()
) # Commit to ensure Electric SQL syncs the notification update ) # Commit to ensure Zero syncs the notification update
else: else:
# Actual failure # Actual failure
logger.error(f"Indexing failed: {error_or_warning}") logger.error(f"Indexing failed: {error_or_warning}")
if _is_auth_error(str(error_or_warning)):
await _persist_auth_expired(session, connector_id)
if notification: if notification:
# Refresh notification to ensure it's not stale after indexing function commits # Refresh notification to ensure it's not stale after indexing function commits
await session.refresh(notification) await session.refresh(notification)
@ -1525,13 +1637,13 @@ async def _run_indexing_with_notifications(
) )
await ( await (
session.commit() session.commit()
) # Commit to ensure Electric SQL syncs the notification update ) # Commit to ensure Zero syncs the notification update
else: else:
# Success - just no new documents to index (all skipped/unchanged) # Success - just no new documents to index (all skipped/unchanged)
logger.info( logger.info(
"Indexing completed: No new documents to process (all up to date)" "Indexing completed: No new documents to process (all up to date)"
) )
# Still update timestamp so ElectricSQL syncs and clears "Syncing" UI # Still update timestamp so Zero syncs and clears "Syncing" UI
if update_timestamp_func: if update_timestamp_func:
await update_timestamp_func(session, connector_id) await update_timestamp_func(session, connector_id)
await session.commit() # Commit timestamp update await session.commit() # Commit timestamp update
@ -1547,7 +1659,7 @@ async def _run_indexing_with_notifications(
) )
await ( await (
session.commit() session.commit()
) # Commit to ensure Electric SQL syncs the notification update ) # Commit to ensure Zero syncs the notification update
except SoftTimeLimitExceeded: except SoftTimeLimitExceeded:
# Celery soft time limit was reached - task is about to be killed # Celery soft time limit was reached - task is about to be killed
# Gracefully save progress and mark as interrupted # Gracefully save progress and mark as interrupted
@ -1577,6 +1689,9 @@ async def _run_indexing_with_notifications(
except Exception as e: except Exception as e:
logger.error(f"Error in indexing task: {e!s}", exc_info=True) logger.error(f"Error in indexing task: {e!s}", exc_info=True)
if _is_auth_error(str(e)):
await _persist_auth_expired(session, connector_id)
# Update notification on exception # Update notification on exception
if notification: if notification:
try: try:
@ -2172,10 +2287,9 @@ async def run_google_gmail_indexing(
end_date: str | None, end_date: str | None,
update_last_indexed: bool, update_last_indexed: bool,
on_heartbeat_callback=None, on_heartbeat_callback=None,
) -> tuple[int, str | None]: ) -> tuple[int, int, str | None]:
# Use a reasonable default for max_messages
max_messages = 1000 max_messages = 1000
indexed_count, error_message = await index_google_gmail_messages( indexed_count, skipped_count, error_message = await index_google_gmail_messages(
session=session, session=session,
connector_id=connector_id, connector_id=connector_id,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -2186,8 +2300,7 @@ async def run_google_gmail_indexing(
max_messages=max_messages, max_messages=max_messages,
on_heartbeat_callback=on_heartbeat_callback, on_heartbeat_callback=on_heartbeat_callback,
) )
# index_google_gmail_messages returns (int, str) but we need (int, str | None) return indexed_count, skipped_count, error_message if error_message else None
return indexed_count, error_message if error_message else None
await _run_indexing_with_notifications( await _run_indexing_with_notifications(
session=session, session=session,
@ -2223,6 +2336,7 @@ async def run_google_drive_indexing(
items = GoogleDriveIndexRequest(**items_dict) items = GoogleDriveIndexRequest(**items_dict)
indexing_options = items.indexing_options indexing_options = items.indexing_options
total_indexed = 0 total_indexed = 0
total_skipped = 0
errors = [] errors = []
# Get connector info for notification # Get connector info for notification
@ -2260,7 +2374,11 @@ async def run_google_drive_indexing(
# Index each folder with indexing options # Index each folder with indexing options
for folder in items.folders: for folder in items.folders:
try: try:
indexed_count, error_message = await index_google_drive_files( (
indexed_count,
skipped_count,
error_message,
) = await index_google_drive_files(
session, session,
connector_id, connector_id,
search_space_id, search_space_id,
@ -2272,6 +2390,7 @@ async def run_google_drive_indexing(
max_files=indexing_options.max_files_per_folder, max_files=indexing_options.max_files_per_folder,
include_subfolders=indexing_options.include_subfolders, include_subfolders=indexing_options.include_subfolders,
) )
total_skipped += skipped_count
if error_message: if error_message:
errors.append(f"Folder '{folder.name}': {error_message}") errors.append(f"Folder '{folder.name}': {error_message}")
else: else:
@ -2312,9 +2431,15 @@ async def run_google_drive_indexing(
logger.error( logger.error(
f"Google Drive indexing completed with errors for connector {connector_id}: {error_message}" f"Google Drive indexing completed with errors for connector {connector_id}: {error_message}"
) )
if _is_auth_error(error_message):
await _persist_auth_expired(session, connector_id)
error_message = (
"Google Drive authentication expired. Please re-authenticate."
)
else: else:
# Update notification to storing stage # Update notification to storing stage
if notification: if notification:
await session.refresh(notification)
await NotificationService.connector_indexing.notify_indexing_progress( await NotificationService.connector_indexing.notify_indexing_progress(
session=session, session=session,
notification=notification, notification=notification,
@ -2338,6 +2463,7 @@ async def run_google_drive_indexing(
notification=notification, notification=notification,
indexed_count=total_indexed, indexed_count=total_indexed,
error_message=error_message, error_message=error_message,
skipped_count=total_skipped,
) )
except Exception as e: except Exception as e:
@ -2650,7 +2776,7 @@ async def run_composio_indexing(
Run Composio connector indexing with real-time notifications. Run Composio connector indexing with real-time notifications.
This wraps the Composio indexer with the notification system so that This wraps the Composio indexer with the notification system so that
Electric SQL can sync indexing progress to the frontend in real-time. Zero can sync indexing progress to the frontend in real-time.
Args: Args:
session: Database session session: Database session
@ -2715,9 +2841,14 @@ async def create_mcp_connector(
"You don't have permission to create connectors in this search space", "You don't have permission to create connectors in this search space",
) )
# Ensure unique name across MCP connectors in this search space
unique_name = await ensure_unique_connector_name(
session, connector_data.name, search_space_id, user.id
)
# Create the connector with single server config # Create the connector with single server config
db_connector = SearchSourceConnector( db_connector = SearchSourceConnector(
name=connector_data.name, name=unique_name,
connector_type=SearchSourceConnectorType.MCP_CONNECTOR, connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
is_indexable=False, # MCP connectors are not indexable is_indexable=False, # MCP connectors are not indexable
config={"server_config": connector_data.server_config.model_dump()}, config={"server_config": connector_data.server_config.model_dump()},
@ -3136,6 +3267,12 @@ async def get_drive_picker_token(
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to get Drive picker token: {e!s}", exc_info=True) logger.error(f"Failed to get Drive picker token: {e!s}", exc_info=True)
if _is_auth_error(str(e)):
await _persist_auth_expired(session, connector_id)
raise HTTPException(
status_code=400,
detail="Google Drive authentication expired. Please re-authenticate.",
) from e
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Failed to retrieve access token. Check server logs for details.", detail="Failed to retrieve access token. Check server logs for details.",

View file

@ -166,7 +166,7 @@ async def slack_callback(
# Redirect to frontend with error parameter # Redirect to frontend with error parameter
if space_id: if space_id:
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=slack_oauth_denied" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=slack_oauth_denied"
) )
else: else:
return RedirectResponse( return RedirectResponse(
@ -296,7 +296,7 @@ async def slack_callback(
f"Duplicate Slack connector detected for user {user_id} with workspace {connector_identifier}" f"Duplicate Slack connector detected for user {user_id} with workspace {connector_identifier}"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=slack-connector" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=slack-connector"
) )
# Generate a unique, user-friendly connector name # Generate a unique, user-friendly connector name
@ -328,7 +328,7 @@ async def slack_callback(
# Redirect to the frontend with success params # Redirect to the frontend with success params
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=slack-connector&connectorId={new_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=slack-connector&connectorId={new_connector.id}"
) )
except ValidationError as e: except ValidationError as e:

View file

@ -456,7 +456,7 @@ async def create_comment(
thread = message.thread thread = message.thread
comment = ChatComment( comment = ChatComment(
message_id=message_id, message_id=message_id,
thread_id=thread.id, # Denormalized for efficient Electric subscriptions thread_id=thread.id, # Denormalized for efficient per-thread sync
author_id=user.id, author_id=user.id,
content=content, content=content,
) )
@ -569,7 +569,7 @@ async def create_reply(
thread = parent_comment.message.thread thread = parent_comment.message.thread
reply = ChatComment( reply = ChatComment(
message_id=parent_comment.message_id, message_id=parent_comment.message_id,
thread_id=thread.id, # Denormalized for efficient Electric subscriptions thread_id=thread.id, # Denormalized for efficient per-thread sync
parent_id=comment_id, parent_id=comment_id,
author_id=user.id, author_id=user.id,
content=content, content=content,

View file

@ -36,32 +36,14 @@ TOOLKIT_TO_CONNECTOR_TYPE = {
} }
# Mapping of toolkit IDs to document types # Mapping of toolkit IDs to document types
TOOLKIT_TO_DOCUMENT_TYPE = { # Google Drive, Gmail, Calendar use unified native indexers - not in this registry
"googledrive": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR", TOOLKIT_TO_DOCUMENT_TYPE: dict[str, str] = {}
"gmail": "COMPOSIO_GMAIL_CONNECTOR",
"googlecalendar": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
}
# Mapping of toolkit IDs to their indexer functions # Mapping of toolkit IDs to their indexer functions
# Format: toolkit_id -> (module_path, function_name, supports_date_filter) # Format: toolkit_id -> (module_path, function_name, supports_date_filter)
# supports_date_filter: True if the indexer accepts start_date/end_date params # supports_date_filter: True if the indexer accepts start_date/end_date params
TOOLKIT_TO_INDEXER = { # Google Drive, Gmail, Calendar use unified native indexers - not in this registry
"googledrive": ( TOOLKIT_TO_INDEXER: dict[str, tuple[str, str, bool]] = {}
"app.connectors.composio_google_drive_connector",
"index_composio_google_drive",
False, # Google Drive doesn't use date filtering
),
"gmail": (
"app.connectors.composio_gmail_connector",
"index_composio_gmail",
True, # Gmail uses date filtering
),
"googlecalendar": (
"app.connectors.composio_google_calendar_connector",
"index_composio_google_calendar",
True, # Calendar uses date filtering
),
}
class ComposioService: class ComposioService:
@ -247,6 +229,68 @@ class ComposioService:
) )
return False return False
def refresh_connected_account(
self,
connected_account_id: str,
redirect_url: str | None = None,
) -> dict[str, Any]:
"""
Refresh an expired Composio connected account.
For OAuth flows this returns a redirect_url the user must visit to
re-authorise. The same connected_account_id stays valid afterwards.
Args:
connected_account_id: The Composio connected account nanoid.
redirect_url: Where Composio should redirect after re-auth.
Returns:
Dict with id, status, and redirect_url (None when no redirect needed).
"""
kwargs: dict[str, Any] = {}
if redirect_url is not None:
kwargs["body_redirect_url"] = redirect_url
result = self.client.connected_accounts.refresh(
nanoid=connected_account_id,
**kwargs,
)
return {
"id": result.id,
"status": result.status,
"redirect_url": result.redirect_url,
}
def wait_for_connection(
self,
connected_account_id: str,
timeout: float = 30.0,
) -> str:
"""
Poll Composio until the connected account reaches ACTIVE status.
Must be called after refresh() / initiate() to ensure Composio has
finished exchanging the authorization code for valid tokens.
Returns:
The final account status string (should be "ACTIVE").
Raises:
TimeoutError: If the account does not become ACTIVE within *timeout*.
"""
try:
account = self.client.connected_accounts.wait_for_connection(
id=connected_account_id,
timeout=timeout,
)
status = getattr(account, "status", "UNKNOWN")
logger.info(f"Composio account {connected_account_id} is now {status}")
return status
except Exception as e:
logger.error(
f"Timeout/error waiting for Composio account {connected_account_id}: {e!s}"
)
raise
def get_access_token(self, connected_account_id: str) -> str: def get_access_token(self, connected_account_id: str) -> str:
"""Retrieve the raw OAuth access token for a Composio connected account.""" """Retrieve the raw OAuth access token for a Composio connected account."""
account = self.client.connected_accounts.get(nanoid=connected_account_id) account = self.client.connected_accounts.get(nanoid=connected_account_id)
@ -258,6 +302,12 @@ class ComposioService:
access_token = getattr(token, "access_token", None) access_token = getattr(token, "access_token", None)
if not access_token: if not access_token:
raise ValueError(f"No access_token in state.val for {connected_account_id}") raise ValueError(f"No access_token in state.val for {connected_account_id}")
if len(access_token) < 20:
raise ValueError(
f"Composio returned a masked access_token ({len(access_token)} chars) "
f"for account {connected_account_id}. Disable 'Mask Connected Account "
f"Secrets' in Composio dashboard: Settings → Project Settings."
)
return access_token return access_token
async def execute_tool( async def execute_tool(

View file

@ -0,0 +1,13 @@
from app.services.confluence.kb_sync_service import ConfluenceKBSyncService
from app.services.confluence.tool_metadata_service import (
ConfluencePage,
ConfluenceToolMetadataService,
ConfluenceWorkspace,
)
__all__ = [
"ConfluenceKBSyncService",
"ConfluencePage",
"ConfluenceToolMetadataService",
"ConfluenceWorkspace",
]

View file

@ -0,0 +1,240 @@
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
class ConfluenceKBSyncService:
"""Syncs Confluence page documents to the knowledge base after HITL actions."""
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def sync_after_create(
self,
page_id: str,
page_title: str,
space_id: str,
body_content: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
return {"status": "success"}
indexable_content = (body_content or "").strip()
if not indexable_content:
indexable_content = f"Confluence Page: {page_title}"
page_content = f"# {page_title}\n\n{indexable_content}"
content_hash = generate_content_hash(page_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"page_title": page_title,
"space_id": space_id,
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
page_content, user_llm, doc_metadata_for_summary
)
else:
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(page_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document = Document(
title=page_title,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
document_metadata={
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": 0,
"indexed_at": now_str,
"connector_id": connector_id,
},
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, page=%s",
document.id,
page_title,
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for page %s: %s",
page_title,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def sync_after_update(
self,
document_id: int,
page_id: str,
user_id: str,
search_space_id: int,
) -> dict:
from app.tasks.connector_indexers.base import (
get_current_timestamp,
safe_set_chunks,
)
try:
document = await self.db_session.get(Document, document_id)
if not document:
return {"status": "not_indexed"}
connector_id = document.connector_id
if not connector_id:
return {"status": "error", "message": "Document has no connector_id"}
client = ConfluenceHistoryConnector(
session=self.db_session, connector_id=connector_id
)
page_data = await client.get_page(page_id)
await client.close()
page_title = page_data.get("title", "")
body_obj = page_data.get("body", {})
body_content = ""
if isinstance(body_obj, dict):
storage = body_obj.get("storage", {})
if isinstance(storage, dict):
body_content = storage.get("value", "")
page_content = f"# {page_title}\n\n{body_content}"
if not page_content.strip():
return {"status": "error", "message": "Page produced empty content"}
space_id = (document.document_metadata or {}).get("space_id", "")
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)
if user_llm:
doc_meta = {
"page_title": page_title,
"space_id": space_id,
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
summary_content, summary_embedding = await generate_document_summary(
page_content, user_llm, doc_meta
)
else:
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(page_content)
document.title = page_title
document.content = summary_content
document.content_hash = generate_content_hash(page_content, search_space_id)
document.embedding = summary_embedding
from sqlalchemy.orm.attributes import flag_modified
document.document_metadata = {
**(document.document_metadata or {}),
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
flag_modified(document, "document_metadata")
await safe_set_chunks(self.db_session, document, chunks)
document.updated_at = get_current_timestamp()
await self.db_session.commit()
logger.info(
"KB sync successful for document %s (%s)",
document_id,
page_title,
)
return {"status": "success"}
except Exception as e:
logger.error(
"KB sync failed for document %s: %s", document_id, e, exc_info=True
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}

View file

@ -0,0 +1,314 @@
import logging
from dataclasses import dataclass
from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
logger = logging.getLogger(__name__)
@dataclass
class ConfluenceWorkspace:
"""Represents a Confluence connector as a workspace for tool context."""
id: int
name: str
base_url: str
@classmethod
def from_connector(cls, connector: SearchSourceConnector) -> "ConfluenceWorkspace":
return cls(
id=connector.id,
name=connector.name,
base_url=connector.config.get("base_url", ""),
)
def to_dict(self) -> dict:
return {
"id": self.id,
"name": self.name,
"base_url": self.base_url,
}
@dataclass
class ConfluencePage:
"""Represents an indexed Confluence page resolved from the knowledge base."""
page_id: str
page_title: str
space_id: str
connector_id: int
document_id: int
indexed_at: str | None
@classmethod
def from_document(cls, document: Document) -> "ConfluencePage":
meta = document.document_metadata or {}
return cls(
page_id=meta.get("page_id", ""),
page_title=meta.get("page_title", document.title),
space_id=meta.get("space_id", ""),
connector_id=document.connector_id,
document_id=document.id,
indexed_at=meta.get("indexed_at"),
)
def to_dict(self) -> dict:
return {
"page_id": self.page_id,
"page_title": self.page_title,
"space_id": self.space_id,
"connector_id": self.connector_id,
"document_id": self.document_id,
"indexed_at": self.indexed_at,
}
class ConfluenceToolMetadataService:
"""Builds interrupt context for Confluence HITL tools."""
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
"""Check if the Confluence connector auth is still valid.
Returns True if auth is expired/invalid, False if healthy.
"""
try:
client = ConfluenceHistoryConnector(
session=self._db_session, connector_id=connector.id
)
await client._get_valid_token()
await client.close()
return False
except Exception as e:
logger.warning(
"Confluence connector %s health check failed: %s", connector.id, e
)
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await self._db_session.commit()
await self._db_session.refresh(connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return True
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
"""Return context needed to create a new Confluence page.
Fetches all connected accounts, and for the first healthy one fetches spaces.
"""
connectors = await self._get_all_confluence_connectors(search_space_id, user_id)
if not connectors:
return {"error": "No Confluence account connected"}
accounts = []
spaces = []
fetched_context = False
for connector in connectors:
auth_expired = await self._check_account_health(connector)
workspace = ConfluenceWorkspace.from_connector(connector)
accounts.append(
{
**workspace.to_dict(),
"auth_expired": auth_expired,
}
)
if not auth_expired and not fetched_context:
try:
client = ConfluenceHistoryConnector(
session=self._db_session, connector_id=connector.id
)
raw_spaces = await client.get_all_spaces()
spaces = [
{"id": s.get("id"), "key": s.get("key"), "name": s.get("name")}
for s in raw_spaces
]
await client.close()
fetched_context = True
except Exception as e:
logger.warning(
"Failed to fetch Confluence spaces for connector %s: %s",
connector.id,
e,
)
return {
"accounts": accounts,
"spaces": spaces,
}
async def get_update_context(
self, search_space_id: int, user_id: str, page_ref: str
) -> dict:
"""Return context needed to update an indexed Confluence page.
Resolves the page from KB, then fetches current content and version from API.
"""
document = await self._resolve_page(search_space_id, user_id, page_ref)
if not document:
return {
"error": f"Page '{page_ref}' not found in your synced Confluence pages. "
"Please make sure the page is indexed in your knowledge base."
}
connector = await self._get_connector_for_document(document, user_id)
if not connector:
return {"error": "Connector not found or access denied"}
auth_expired = await self._check_account_health(connector)
if auth_expired:
return {
"error": "Confluence authentication has expired. Please re-authenticate.",
"auth_expired": True,
"connector_id": connector.id,
}
workspace = ConfluenceWorkspace.from_connector(connector)
page = ConfluencePage.from_document(document)
try:
client = ConfluenceHistoryConnector(
session=self._db_session, connector_id=connector.id
)
page_data = await client.get_page(page.page_id)
await client.close()
except Exception as e:
error_str = str(e).lower()
if (
"401" in error_str
or "403" in error_str
or "authentication" in error_str
):
return {
"error": f"Failed to fetch Confluence page: {e!s}",
"auth_expired": True,
"connector_id": connector.id,
}
return {"error": f"Failed to fetch Confluence page: {e!s}"}
body_storage = ""
body_obj = page_data.get("body", {})
if isinstance(body_obj, dict):
storage = body_obj.get("storage", {})
if isinstance(storage, dict):
body_storage = storage.get("value", "")
version_obj = page_data.get("version", {})
version_number = (
version_obj.get("number", 1) if isinstance(version_obj, dict) else 1
)
return {
"account": {**workspace.to_dict(), "auth_expired": False},
"page": {
"page_id": page.page_id,
"page_title": page_data.get("title", page.page_title),
"space_id": page.space_id,
"body": body_storage,
"version": version_number,
"document_id": page.document_id,
"indexed_at": page.indexed_at,
},
}
async def get_deletion_context(
self, search_space_id: int, user_id: str, page_ref: str
) -> dict:
"""Return context needed to delete a Confluence page (KB metadata only)."""
document = await self._resolve_page(search_space_id, user_id, page_ref)
if not document:
return {
"error": f"Page '{page_ref}' not found in your synced Confluence pages. "
"Please make sure the page is indexed in your knowledge base."
}
connector = await self._get_connector_for_document(document, user_id)
if not connector:
return {"error": "Connector not found or access denied"}
auth_expired = connector.config.get("auth_expired", False)
workspace = ConfluenceWorkspace.from_connector(connector)
page = ConfluencePage.from_document(document)
return {
"account": {**workspace.to_dict(), "auth_expired": auth_expired},
"page": page.to_dict(),
}
async def _resolve_page(
self, search_space_id: int, user_id: str, page_ref: str
) -> Document | None:
"""Resolve a page from KB: page_title -> document.title."""
ref_lower = page_ref.lower()
result = await self._db_session.execute(
select(Document)
.join(
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.CONFLUENCE_CONNECTOR,
SearchSourceConnector.user_id == user_id,
or_(
func.lower(Document.document_metadata.op("->>")("page_title"))
== ref_lower,
func.lower(Document.title) == ref_lower,
),
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
return result.scalars().first()
async def _get_all_confluence_connectors(
self, search_space_id: int, user_id: str
) -> list[SearchSourceConnector]:
result = await self._db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
)
)
)
return result.scalars().all()
async def _get_connector_for_document(
self, document: Document, user_id: str
) -> SearchSourceConnector | None:
if not document.connector_id:
return None
result = await self._db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.user_id == user_id,
)
)
)
return result.scalars().first()

View file

@ -11,6 +11,7 @@ from sqlalchemy.future import select
from tavily import TavilyClient from tavily import TavilyClient
from app.db import ( from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk, Chunk,
Document, Document,
SearchSourceConnector, SearchSourceConnector,
@ -219,7 +220,7 @@ class ConnectorService:
self, self,
query_text: str, query_text: str,
search_space_id: int, search_space_id: int,
document_type: str, document_type: str | list[str],
top_k: int = 20, top_k: int = 20,
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
@ -241,7 +242,8 @@ class ConnectorService:
Args: Args:
query_text: The search query text query_text: The search query text
search_space_id: The search space ID to search within search_space_id: The search space ID to search within
document_type: Document type to filter (e.g., "FILE", "CRAWLED_URL") document_type: Document type(s) to filter (e.g., "FILE", "CRAWLED_URL",
or a list for multi-type queries)
top_k: Number of results to return top_k: Number of results to return
start_date: Optional start date for filtering documents by updated_at start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at
@ -254,6 +256,16 @@ class ConnectorService:
perf = get_perf_logger() perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()
# Expand native Google types to include legacy Composio equivalents
# so old documents remain searchable until re-indexed.
if isinstance(document_type, str) and document_type in NATIVE_TO_LEGACY_DOCTYPE:
resolved_type: str | list[str] = [
document_type,
NATIVE_TO_LEGACY_DOCTYPE[document_type],
]
else:
resolved_type = document_type
# RRF constant # RRF constant
k = 60 k = 60
@ -276,7 +288,7 @@ class ConnectorService:
"query_text": query_text, "query_text": query_text,
"top_k": retriever_top_k, "top_k": retriever_top_k,
"search_space_id": search_space_id, "search_space_id": search_space_id,
"document_type": document_type, "document_type": resolved_type,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"query_embedding": query_embedding, "query_embedding": query_embedding,
@ -2746,299 +2758,6 @@ class ConnectorService:
return result_object, obsidian_docs return result_object, obsidian_docs
# =========================================================================
# Composio Connector Search Methods
# =========================================================================
async def search_composio_google_drive(
self,
user_query: str,
search_space_id: int,
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> tuple:
"""
Search for Composio Google Drive files and return both the source information
and langchain documents.
Uses combined chunk-level and document-level hybrid search with RRF fusion.
Args:
user_query: The user's query
search_space_id: The search space ID to search in
top_k: Maximum number of results to return
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
Returns:
tuple: (sources_info, langchain_documents)
"""
composio_drive_docs = await self._combined_rrf_search(
query_text=user_query,
search_space_id=search_space_id,
document_type="COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
top_k=top_k,
start_date=start_date,
end_date=end_date,
)
# Early return if no results
if not composio_drive_docs:
return {
"id": 54,
"name": "Google Drive (Composio)",
"type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
"sources": [],
}, []
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return (
doc_info.get("title")
or metadata.get("title")
or metadata.get("file_name")
or "Untitled Document"
)
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return metadata.get("url") or metadata.get("web_view_link") or ""
def _description_fn(
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> str:
description = self._chunk_preview(chunk.get("content", ""), limit=200)
info_parts = []
mime_type = metadata.get("mime_type")
modified_time = metadata.get("modified_time")
if mime_type:
info_parts.append(f"Type: {mime_type}")
if modified_time:
info_parts.append(f"Modified: {modified_time}")
if info_parts:
description = (description + " | " + " | ".join(info_parts)).strip(" |")
return description
def _extra_fields_fn(
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> dict[str, Any]:
return {
"mime_type": metadata.get("mime_type", ""),
"file_id": metadata.get("file_id", ""),
"modified_time": metadata.get("modified_time", ""),
}
sources_list = self._build_chunk_sources_from_documents(
composio_drive_docs,
title_fn=_title_fn,
url_fn=_url_fn,
description_fn=_description_fn,
extra_fields_fn=_extra_fields_fn,
)
# Create result object
result_object = {
"id": 54,
"name": "Google Drive (Composio)",
"type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
"sources": sources_list,
}
return result_object, composio_drive_docs
async def search_composio_gmail(
self,
user_query: str,
search_space_id: int,
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> tuple:
"""
Search for Composio Gmail messages and return both the source information
and langchain documents.
Uses combined chunk-level and document-level hybrid search with RRF fusion.
Args:
user_query: The user's query
search_space_id: The search space ID to search in
top_k: Maximum number of results to return
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
Returns:
tuple: (sources_info, langchain_documents)
"""
composio_gmail_docs = await self._combined_rrf_search(
query_text=user_query,
search_space_id=search_space_id,
document_type="COMPOSIO_GMAIL_CONNECTOR",
top_k=top_k,
start_date=start_date,
end_date=end_date,
)
# Early return if no results
if not composio_gmail_docs:
return {
"id": 55,
"name": "Gmail (Composio)",
"type": "COMPOSIO_GMAIL_CONNECTOR",
"sources": [],
}, []
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return (
doc_info.get("title")
or metadata.get("subject")
or metadata.get("title")
or "Untitled Email"
)
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return metadata.get("url") or ""
def _description_fn(
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> str:
description = self._chunk_preview(chunk.get("content", ""), limit=200)
info_parts = []
sender = metadata.get("from") or metadata.get("sender")
date = metadata.get("date") or metadata.get("received_at")
if sender:
info_parts.append(f"From: {sender}")
if date:
info_parts.append(f"Date: {date}")
if info_parts:
description = (description + " | " + " | ".join(info_parts)).strip(" |")
return description
def _extra_fields_fn(
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> dict[str, Any]:
return {
"message_id": metadata.get("message_id", ""),
"thread_id": metadata.get("thread_id", ""),
"from": metadata.get("from", ""),
"to": metadata.get("to", ""),
"date": metadata.get("date", ""),
}
sources_list = self._build_chunk_sources_from_documents(
composio_gmail_docs,
title_fn=_title_fn,
url_fn=_url_fn,
description_fn=_description_fn,
extra_fields_fn=_extra_fields_fn,
)
# Create result object
result_object = {
"id": 55,
"name": "Gmail (Composio)",
"type": "COMPOSIO_GMAIL_CONNECTOR",
"sources": sources_list,
}
return result_object, composio_gmail_docs
async def search_composio_google_calendar(
self,
user_query: str,
search_space_id: int,
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> tuple:
"""
Search for Composio Google Calendar events and return both the source information
and langchain documents.
Uses combined chunk-level and document-level hybrid search with RRF fusion.
Args:
user_query: The user's query
search_space_id: The search space ID to search in
top_k: Maximum number of results to return
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
Returns:
tuple: (sources_info, langchain_documents)
"""
composio_calendar_docs = await self._combined_rrf_search(
query_text=user_query,
search_space_id=search_space_id,
document_type="COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
top_k=top_k,
start_date=start_date,
end_date=end_date,
)
# Early return if no results
if not composio_calendar_docs:
return {
"id": 56,
"name": "Google Calendar (Composio)",
"type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
"sources": [],
}, []
def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return (
doc_info.get("title")
or metadata.get("summary")
or metadata.get("title")
or "Untitled Event"
)
def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
return metadata.get("url") or metadata.get("html_link") or ""
def _description_fn(
chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> str:
description = self._chunk_preview(chunk.get("content", ""), limit=200)
info_parts = []
start_time = metadata.get("start_time") or metadata.get("start")
end_time = metadata.get("end_time") or metadata.get("end")
if start_time:
info_parts.append(f"Start: {start_time}")
if end_time:
info_parts.append(f"End: {end_time}")
if info_parts:
description = (description + " | " + " | ".join(info_parts)).strip(" |")
return description
def _extra_fields_fn(
_chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
) -> dict[str, Any]:
return {
"event_id": metadata.get("event_id", ""),
"calendar_id": metadata.get("calendar_id", ""),
"start_time": metadata.get("start_time", ""),
"end_time": metadata.get("end_time", ""),
"location": metadata.get("location", ""),
}
sources_list = self._build_chunk_sources_from_documents(
composio_calendar_docs,
title_fn=_title_fn,
url_fn=_url_fn,
description_fn=_description_fn,
extra_fields_fn=_extra_fields_fn,
)
# Create result object
result_object = {
"id": 56,
"name": "Google Calendar (Composio)",
"type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
"sources": sources_list,
}
return result_object, composio_calendar_docs
# ========================================================================= # =========================================================================
# Utility Methods for Connector Discovery # Utility Methods for Connector Discovery
# ========================================================================= # =========================================================================

View file

@ -0,0 +1,13 @@
from app.services.gmail.kb_sync_service import GmailKBSyncService
from app.services.gmail.tool_metadata_service import (
GmailAccount,
GmailMessage,
GmailToolMetadataService,
)
__all__ = [
"GmailAccount",
"GmailKBSyncService",
"GmailMessage",
"GmailToolMetadataService",
]

View file

@ -0,0 +1,169 @@
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
class GmailKBSyncService:
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def sync_after_create(
self,
message_id: str,
thread_id: str,
subject: str,
sender: str,
date_str: str,
body_text: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
draft_id: str | None = None,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
logger.info(
"Document for Gmail message %s already exists (doc_id=%s), skipping",
message_id,
existing.id,
)
return {"status": "success"}
indexable_content = (
f"Gmail Message: {subject}\n\nFrom: {sender}\nDate: {date_str}\n\n"
f"{body_text or ''}"
).strip()
if not indexable_content:
indexable_content = f"Gmail message: {subject}"
content_hash = generate_content_hash(indexable_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
logger.info(
"Content-hash collision for Gmail message %s -- identical content "
"exists in doc %s. Using unique_identifier_hash as content_hash.",
message_id,
dup.id,
)
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"subject": subject,
"sender": sender,
"document_type": "Gmail Message",
"connector_type": "Gmail",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
indexable_content, user_llm, doc_metadata_for_summary
)
else:
logger.warning("No LLM configured -- using fallback summary")
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
doc_metadata = {
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
"indexed_at": now_str,
}
if draft_id:
doc_metadata["draft_id"] = draft_id
document = Document(
title=subject,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
document_metadata=doc_metadata,
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
source_markdown=body_text,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, subject=%s, chunks=%d",
document.id,
subject,
len(chunks),
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
logger.warning(
"Duplicate constraint hit during KB sync for message %s. "
"Rolling back -- periodic indexer will handle it. Error: %s",
message_id,
e,
)
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for message %s: %s",
message_id,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}

View file

@ -0,0 +1,451 @@
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy import String, and_, cast, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
@dataclass
class GmailAccount:
id: int
name: str
email: str
@classmethod
def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
email = ""
if connector.name and " - " in connector.name:
email = connector.name.split(" - ", 1)[1]
return cls(id=connector.id, name=connector.name, email=email)
def to_dict(self) -> dict:
return {"id": self.id, "name": self.name, "email": self.email}
@dataclass
class GmailMessage:
message_id: str
thread_id: str
subject: str
sender: str
date: str
connector_id: int
document_id: int
@classmethod
def from_document(cls, document: Document) -> "GmailMessage":
meta = document.document_metadata or {}
return cls(
message_id=meta.get("message_id", ""),
thread_id=meta.get("thread_id", ""),
subject=meta.get("subject", document.title),
sender=meta.get("sender", ""),
date=meta.get("date", ""),
connector_id=document.connector_id,
document_id=document.id,
)
def to_dict(self) -> dict:
return {
"message_id": self.message_id,
"thread_id": self.thread_id,
"subject": self.subject,
"sender": self.sender,
"date": self.date,
"connector_id": self.connector_id,
"document_id": self.document_id,
}
class GmailToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
config_data = dict(connector.config)
from app.config import config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
return Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
async def _check_account_health(self, connector_id: int) -> bool:
"""Check if a Gmail connector's credentials are still valid.
Uses a lightweight ``users().getProfile(userId='me')`` call.
Returns True if the credentials are expired/invalid, False if healthy.
"""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if not connector:
return True
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
await asyncio.get_event_loop().run_in_executor(
None, lambda: service.users().getProfile(userId="me").execute()
)
return False
except Exception as e:
logger.warning(
"Gmail connector %s health check failed: %s",
connector_id,
e,
)
return True
async def _persist_auth_expired(self, connector_id: int) -> None:
"""Persist ``auth_expired: True`` to the connector config if not already set."""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
db_connector = result.scalar_one_or_none()
if db_connector and not db_connector.config.get("auth_expired"):
db_connector.config = {**db_connector.config, "auth_expired": True}
flag_modified(db_connector, "config")
await self._db_session.commit()
await self._db_session.refresh(db_connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector_id,
exc_info=True,
)
async def _get_accounts(
self, search_space_id: int, user_id: str
) -> list[GmailAccount]:
result = await self._db_session.execute(
select(SearchSourceConnector)
.filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(
[
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
),
)
)
.order_by(SearchSourceConnector.last_indexed_at.desc())
)
connectors = result.scalars().all()
return [GmailAccount.from_connector(c) for c in connectors]
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_accounts(search_space_id, user_id)
if not accounts:
return {
"accounts": [],
"error": "No Gmail account connected",
}
accounts_with_status = []
for acc in accounts:
acc_dict = acc.to_dict()
auth_expired = await self._check_account_health(acc.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(acc.id)
else:
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == acc.id
)
)
connector = result.scalar_one_or_none()
if connector:
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
profile = await asyncio.get_event_loop().run_in_executor(
None,
lambda service=service: (
service.users().getProfile(userId="me").execute()
),
)
acc_dict["email"] = profile.get("emailAddress", "")
except Exception:
logger.warning(
"Failed to fetch email for Gmail connector %s",
acc.id,
exc_info=True,
)
accounts_with_status.append(acc_dict)
return {"accounts": accounts_with_status}
async def get_update_context(
self, search_space_id: int, user_id: str, email_ref: str
) -> dict:
document, connector = await self._resolve_email(
search_space_id, user_id, email_ref
)
if not document or not connector:
return {
"error": (
f"Draft '{email_ref}' not found in your indexed Gmail messages. "
"This could mean: (1) the draft doesn't exist, "
"(2) it hasn't been indexed yet, "
"or (3) the subject is different. "
"Please check the exact draft subject in Gmail."
)
}
account = GmailAccount.from_connector(connector)
message = GmailMessage.from_document(document)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
result: dict = {
"account": acc_dict,
"email": message.to_dict(),
}
meta = document.document_metadata or {}
if meta.get("draft_id"):
result["draft_id"] = meta["draft_id"]
if not auth_expired:
existing_body = await self._fetch_draft_body(
connector, message.message_id, meta.get("draft_id")
)
if existing_body is not None:
result["existing_body"] = existing_body
return result
async def _fetch_draft_body(
self,
connector: SearchSourceConnector,
message_id: str,
draft_id: str | None,
) -> str | None:
"""Fetch the plain-text body of a Gmail draft via the API.
Tries ``drafts.get`` first (if *draft_id* is available), then falls
back to scanning ``drafts.list`` to resolve the draft by *message_id*.
Returns ``None`` on any failure so callers can degrade gracefully.
"""
try:
creds = await self._build_credentials(connector)
service = build("gmail", "v1", credentials=creds)
if not draft_id:
draft_id = await self._find_draft_id(service, message_id)
if not draft_id:
return None
draft = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.users()
.drafts()
.get(userId="me", id=draft_id, format="full")
.execute()
),
)
payload = draft.get("message", {}).get("payload", {})
return self._extract_body_from_payload(payload)
except Exception:
logger.warning(
"Failed to fetch draft body for message_id=%s",
message_id,
exc_info=True,
)
return None
async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
"""Resolve a draft ID from its message ID by scanning drafts.list."""
try:
page_token = None
while True:
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
if page_token:
kwargs["pageToken"] = page_token
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda kwargs=kwargs: (
service.users().drafts().list(**kwargs).execute()
),
)
for draft in response.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft["id"]
page_token = response.get("nextPageToken")
if not page_token:
break
return None
except Exception:
logger.warning(
"Failed to look up draft by message_id=%s", message_id, exc_info=True
)
return None
@staticmethod
def _extract_body_from_payload(payload: dict) -> str | None:
"""Extract the plain-text (or html→text) body from a Gmail payload."""
import base64
def _get_parts(p: dict) -> list[dict]:
if "parts" in p:
parts: list[dict] = []
for sub in p["parts"]:
parts.extend(_get_parts(sub))
return parts
return [p]
parts = _get_parts(payload)
text_content = ""
for part in parts:
mime_type = part.get("mimeType", "")
data = part.get("body", {}).get("data", "")
if mime_type == "text/plain" and data:
text_content += base64.urlsafe_b64decode(data + "===").decode(
"utf-8", errors="ignore"
)
elif mime_type == "text/html" and data and not text_content:
from markdownify import markdownify as md
raw_html = base64.urlsafe_b64decode(data + "===").decode(
"utf-8", errors="ignore"
)
text_content = md(raw_html).strip()
return text_content.strip() if text_content.strip() else None
async def get_trash_context(
self, search_space_id: int, user_id: str, email_ref: str
) -> dict:
document, connector = await self._resolve_email(
search_space_id, user_id, email_ref
)
if not document or not connector:
return {
"error": (
f"Email '{email_ref}' not found in your indexed Gmail messages. "
"This could mean: (1) the email doesn't exist, "
"(2) it hasn't been indexed yet, "
"or (3) the subject is different."
)
}
account = GmailAccount.from_connector(connector)
message = GmailMessage.from_document(document)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return {
"account": acc_dict,
"email": message.to_dict(),
}
async def _resolve_email(
self, search_space_id: int, user_id: str, email_ref: str
) -> tuple[Document | None, SearchSourceConnector | None]:
result = await self._db_session.execute(
select(Document, SearchSourceConnector)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type.in_(
[
DocumentType.GOOGLE_GMAIL_CONNECTOR,
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
]
),
SearchSourceConnector.user_id == user_id,
or_(
func.lower(cast(Document.document_metadata["subject"], String))
== func.lower(email_ref),
func.lower(Document.title) == func.lower(email_ref),
),
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
row = result.first()
if row:
return row[0], row[1]
return None, None

View file

@ -0,0 +1,13 @@
from app.services.google_calendar.kb_sync_service import GoogleCalendarKBSyncService
from app.services.google_calendar.tool_metadata_service import (
GoogleCalendarAccount,
GoogleCalendarEvent,
GoogleCalendarToolMetadataService,
)
__all__ = [
"GoogleCalendarAccount",
"GoogleCalendarEvent",
"GoogleCalendarKBSyncService",
"GoogleCalendarToolMetadataService",
]

View file

@ -0,0 +1,374 @@
import asyncio
import logging
from datetime import datetime
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
class GoogleCalendarKBSyncService:
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def sync_after_create(
self,
event_id: str,
event_summary: str,
calendar_id: str,
start_time: str,
end_time: str,
location: str | None,
html_link: str | None,
description: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
logger.info(
"Document for Calendar event %s already exists (doc_id=%s), skipping",
event_id,
existing.id,
)
return {"status": "success"}
indexable_content = (
f"Google Calendar Event: {event_summary}\n\n"
f"Start: {start_time}\n"
f"End: {end_time}\n"
f"Location: {location or 'N/A'}\n\n"
f"{description or ''}"
).strip()
content_hash = generate_content_hash(indexable_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
logger.info(
"Content-hash collision for Calendar event %s -- identical content "
"exists in doc %s. Using unique_identifier_hash as content_hash.",
event_id,
dup.id,
)
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"event_summary": event_summary,
"start_time": start_time,
"end_time": end_time,
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
indexable_content, user_llm, doc_metadata_for_summary
)
else:
logger.warning("No LLM configured -- using fallback summary")
summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document = Document(
title=event_summary,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
document_metadata={
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"html_link": html_link,
"source_connector": "google_calendar",
"indexed_at": now_str,
"connector_id": connector_id,
},
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
source_markdown=indexable_content,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, event=%s, chunks=%d",
document.id,
event_summary,
len(chunks),
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
logger.warning(
"Duplicate constraint hit during KB sync for event %s. "
"Rolling back -- periodic indexer will handle it. Error: %s",
event_id,
e,
)
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for event %s: %s",
event_id,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def sync_after_update(
self,
document_id: int,
event_id: str,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
get_current_timestamp,
safe_set_chunks,
)
try:
document = await self.db_session.get(Document, document_id)
if not document:
logger.warning("Document %s not found in KB", document_id)
return {"status": "not_indexed"}
creds = await self._build_credentials_for_connector(connector_id)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
calendar_id = (document.document_metadata or {}).get(
"calendar_id", "primary"
)
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event_id)
.execute()
),
)
event_summary = live_event.get("summary", "")
description = live_event.get("description", "")
location = live_event.get("location", "")
start_data = live_event.get("start", {})
start_time = start_data.get("dateTime", start_data.get("date", ""))
end_data = live_event.get("end", {})
end_time = end_data.get("dateTime", end_data.get("date", ""))
attendees = [
{
"email": a.get("email", ""),
"responseStatus": a.get("responseStatus", ""),
}
for a in live_event.get("attendees", [])
]
indexable_content = (
f"Google Calendar Event: {event_summary}\n\n"
f"Start: {start_time}\n"
f"End: {end_time}\n"
f"Location: {location or 'N/A'}\n\n"
f"{description or ''}"
).strip()
if not indexable_content:
return {"status": "error", "message": "Event produced empty content"}
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)
doc_metadata_for_summary = {
"event_summary": event_summary,
"start_time": start_time,
"end_time": end_time,
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
indexable_content, user_llm, doc_metadata_for_summary
)
else:
summary_content = (
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document.title = event_summary
document.content = summary_content
document.content_hash = generate_content_hash(
indexable_content, search_space_id
)
document.embedding = summary_embedding
document.document_metadata = {
**(document.document_metadata or {}),
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
"attendees": attendees,
"html_link": live_event.get("htmlLink", ""),
"indexed_at": now_str,
"connector_id": connector_id,
}
flag_modified(document, "document_metadata")
await safe_set_chunks(self.db_session, document, chunks)
document.updated_at = get_current_timestamp()
await self.db_session.commit()
logger.info(
"KB sync after update succeeded for document %s (event: %s)",
document_id,
event_summary,
)
return {"status": "success"}
except Exception as e:
logger.error(
"KB sync after update failed for document %s: %s",
document_id,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
result = await self.db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if not connector:
raise ValueError(f"Connector {connector_id} not found")
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
raise ValueError("Composio connected_account_id not found")
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
return Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)

View file

@ -0,0 +1,431 @@
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy import String, and_, cast, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
CALENDAR_CONNECTOR_TYPES = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
CALENDAR_DOCUMENT_TYPES = [
DocumentType.GOOGLE_CALENDAR_CONNECTOR,
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
@dataclass
class GoogleCalendarAccount:
id: int
name: str
@classmethod
def from_connector(
cls, connector: SearchSourceConnector
) -> "GoogleCalendarAccount":
return cls(id=connector.id, name=connector.name)
def to_dict(self) -> dict:
return {"id": self.id, "name": self.name}
@dataclass
class GoogleCalendarEvent:
event_id: str
summary: str
start: str
end: str
description: str
location: str
attendees: list
calendar_id: str
document_id: int
indexed_at: str | None
@classmethod
def from_document(cls, document: Document) -> "GoogleCalendarEvent":
meta = document.document_metadata or {}
return cls(
event_id=meta.get("event_id", ""),
summary=meta.get("event_summary", document.title),
start=meta.get("start_time", ""),
end=meta.get("end_time", ""),
description=meta.get("description", ""),
location=meta.get("location", ""),
attendees=meta.get("attendees", []),
calendar_id=meta.get("calendar_id", "primary"),
document_id=document.id,
indexed_at=meta.get("indexed_at"),
)
def to_dict(self) -> dict:
return {
"event_id": self.event_id,
"summary": self.summary,
"start": self.start,
"end": self.end,
"description": self.description,
"location": self.location,
"attendees": self.attendees,
"calendar_id": self.calendar_id,
"document_id": self.document_id,
"indexed_at": self.indexed_at,
}
class GoogleCalendarToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
return build_composio_credentials(cca_id)
raise ValueError("Composio connected_account_id not found")
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
if config_data.get("token"):
config_data["token"] = token_encryption.decrypt_token(
config_data["token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
if config_data.get("client_secret"):
config_data["client_secret"] = token_encryption.decrypt_token(
config_data["client_secret"]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
return Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
async def _check_account_health(self, connector_id: int) -> bool:
"""Check if a Google Calendar connector's credentials are still valid.
Uses a lightweight calendarList.list(maxResults=1) call to verify access.
Returns True if the credentials are expired/invalid, False if healthy.
"""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if not connector:
return True
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: (
build("calendar", "v3", credentials=creds)
.calendarList()
.list(maxResults=1)
.execute()
),
)
return False
except Exception as e:
logger.warning(
"Google Calendar connector %s health check failed: %s",
connector_id,
e,
)
return True
async def _persist_auth_expired(self, connector_id: int) -> None:
"""Persist ``auth_expired: True`` to the connector config if not already set."""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
db_connector = result.scalar_one_or_none()
if db_connector and not db_connector.config.get("auth_expired"):
db_connector.config = {**db_connector.config, "auth_expired": True}
flag_modified(db_connector, "config")
await self._db_session.commit()
await self._db_session.refresh(db_connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector_id,
exc_info=True,
)
async def _get_accounts(
self, search_space_id: int, user_id: str
) -> list[GoogleCalendarAccount]:
result = await self._db_session.execute(
select(SearchSourceConnector)
.filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
)
)
.order_by(SearchSourceConnector.last_indexed_at.desc())
)
connectors = result.scalars().all()
return [GoogleCalendarAccount.from_connector(c) for c in connectors]
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_accounts(search_space_id, user_id)
if not accounts:
return {
"accounts": [],
"error": "No Google Calendar account connected",
}
accounts_with_status = []
for acc in accounts:
acc_dict = acc.to_dict()
auth_expired = await self._check_account_health(acc.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(acc.id)
accounts_with_status.append(acc_dict)
healthy_account = next(
(a for a in accounts_with_status if not a.get("auth_expired")), None
)
if not healthy_account:
return {
"accounts": accounts_with_status,
"calendars": [],
"timezone": "",
"error": "All connected Google Calendar accounts have expired credentials",
}
connector_id = healthy_account["id"]
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
calendars = []
timezone_str = ""
if connector:
try:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
cal_list = await loop.run_in_executor(
None, lambda: service.calendarList().list().execute()
)
for cal in cal_list.get("items", []):
calendars.append(
{
"id": cal.get("id", ""),
"summary": cal.get("summary", ""),
"primary": cal.get("primary", False),
}
)
tz_setting = await loop.run_in_executor(
None,
lambda: service.settings().get(setting="timezone").execute(),
)
timezone_str = tz_setting.get("value", "")
except Exception:
logger.warning(
"Failed to fetch calendars/timezone for connector %s",
connector_id,
exc_info=True,
)
return {
"accounts": accounts_with_status,
"calendars": calendars,
"timezone": timezone_str,
}
async def get_update_context(
self, search_space_id: int, user_id: str, event_ref: str
) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved:
return {
"error": (
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the event name is different."
)
}
document, connector = resolved
account = GoogleCalendarAccount.from_connector(connector)
event = GoogleCalendarEvent.from_document(document)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return {
"error": "Google Calendar credentials have expired. Please re-authenticate.",
"auth_expired": True,
"connector_id": connector.id,
}
event_dict = event.to_dict()
try:
creds = await self._build_credentials(connector)
loop = asyncio.get_event_loop()
service = await loop.run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
calendar_id = event.calendar_id or "primary"
live_event = await loop.run_in_executor(
None,
lambda: (
service.events()
.get(calendarId=calendar_id, eventId=event.event_id)
.execute()
),
)
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
event_dict["description"] = live_event.get(
"description", event_dict["description"]
)
event_dict["location"] = live_event.get("location", event_dict["location"])
start_data = live_event.get("start", {})
event_dict["start"] = start_data.get(
"dateTime", start_data.get("date", event_dict["start"])
)
end_data = live_event.get("end", {})
event_dict["end"] = end_data.get(
"dateTime", end_data.get("date", event_dict["end"])
)
event_dict["attendees"] = [
{
"email": a.get("email", ""),
"responseStatus": a.get("responseStatus", ""),
}
for a in live_event.get("attendees", [])
]
except Exception:
logger.warning(
"Failed to fetch live event data for event %s, using KB metadata",
event.event_id,
exc_info=True,
)
return {
"account": acc_dict,
"event": event_dict,
}
async def get_deletion_context(
self, search_space_id: int, user_id: str, event_ref: str
) -> dict:
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
if not resolved:
return {
"error": (
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the event name is different."
)
}
document, connector = resolved
account = GoogleCalendarAccount.from_connector(connector)
event = GoogleCalendarEvent.from_document(document)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return {
"account": acc_dict,
"event": event.to_dict(),
}
async def _resolve_event(
self, search_space_id: int, user_id: str, event_ref: str
) -> tuple[Document, SearchSourceConnector] | None:
result = await self._db_session.execute(
select(Document, SearchSourceConnector)
.join(
SearchSourceConnector,
Document.connector_id == SearchSourceConnector.id,
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type.in_(CALENDAR_DOCUMENT_TYPES),
SearchSourceConnector.user_id == user_id,
or_(
func.lower(
cast(Document.document_metadata["event_summary"], String)
)
== func.lower(event_ref),
func.lower(Document.title) == func.lower(event_ref),
),
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
row = result.first()
if row:
return row[0], row[1]
return None

View file

@ -1,3 +1,4 @@
from app.services.google_drive.kb_sync_service import GoogleDriveKBSyncService
from app.services.google_drive.tool_metadata_service import ( from app.services.google_drive.tool_metadata_service import (
GoogleDriveAccount, GoogleDriveAccount,
GoogleDriveFile, GoogleDriveFile,
@ -7,5 +8,6 @@ from app.services.google_drive.tool_metadata_service import (
__all__ = [ __all__ = [
"GoogleDriveAccount", "GoogleDriveAccount",
"GoogleDriveFile", "GoogleDriveFile",
"GoogleDriveKBSyncService",
"GoogleDriveToolMetadataService", "GoogleDriveToolMetadataService",
] ]

View file

@ -0,0 +1,164 @@
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
class GoogleDriveKBSyncService:
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def sync_after_create(
self,
file_id: str,
file_name: str,
mime_type: str,
web_view_link: str | None,
content: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
logger.info(
"Document for Drive file %s already exists (doc_id=%s), skipping",
file_id,
existing.id,
)
return {"status": "success"}
indexable_content = (content or "").strip()
if not indexable_content:
indexable_content = (
f"Google Drive file: {file_name} (type: {mime_type})"
)
content_hash = generate_content_hash(indexable_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
logger.info(
"Content-hash collision for Drive file %s — identical content "
"exists in doc %s. Using unique_identifier_hash as content_hash.",
file_id,
dup.id,
)
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"file_name": file_name,
"mime_type": mime_type,
"document_type": "Google Drive File",
"connector_type": "Google Drive",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
indexable_content, user_llm, doc_metadata_for_summary
)
else:
logger.warning("No LLM configured — using fallback summary")
summary_content = (
f"Google Drive File: {file_name}\n\n{indexable_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(indexable_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document = Document(
title=file_name,
document_type=DocumentType.GOOGLE_DRIVE_FILE,
document_metadata={
"google_drive_file_id": file_id,
"google_drive_file_name": file_name,
"google_drive_mime_type": mime_type,
"web_view_link": web_view_link,
"source_connector": "google_drive",
"indexed_at": now_str,
"connector_id": connector_id,
},
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
source_markdown=content,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, file=%s, chunks=%d",
document.id,
file_name,
len(chunks),
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
logger.warning(
"Duplicate constraint hit during KB sync for file %s. "
"Rolling back — periodic indexer will handle it. Error: %s",
file_id,
e,
)
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for file %s: %s",
file_id,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}

View file

@ -1,15 +1,21 @@
import logging
from dataclasses import dataclass from dataclasses import dataclass
from sqlalchemy import and_, func from sqlalchemy import and_, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.google_drive.client import GoogleDriveClient
from app.db import ( from app.db import (
Document, Document,
DocumentType, DocumentType,
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
) )
from app.utils.google_credentials import build_composio_credentials
logger = logging.getLogger(__name__)
@dataclass @dataclass
@ -68,12 +74,25 @@ class GoogleDriveToolMetadataService:
return { return {
"accounts": [], "accounts": [],
"supported_types": [], "supported_types": [],
"parent_folders": {},
"error": "No Google Drive account connected", "error": "No Google Drive account connected",
} }
accounts_with_status = []
for acc in accounts:
acc_dict = acc.to_dict()
auth_expired = await self._check_account_health(acc.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(acc.id)
accounts_with_status.append(acc_dict)
parent_folders = await self._get_parent_folders_by_account(accounts_with_status)
return { return {
"accounts": [acc.to_dict() for acc in accounts], "accounts": accounts_with_status,
"supported_types": ["google_doc", "google_sheet"], "supported_types": ["google_doc", "google_sheet"],
"parent_folders": parent_folders,
} }
async def get_trash_context( async def get_trash_context(
@ -92,6 +111,8 @@ class GoogleDriveToolMetadataService:
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
) )
) )
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
) )
document = result.scalars().first() document = result.scalars().first()
@ -112,8 +133,12 @@ class GoogleDriveToolMetadataService:
and_( and_(
SearchSourceConnector.id == document.connector_id, SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type.in_(
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
),
) )
) )
) )
@ -125,8 +150,14 @@ class GoogleDriveToolMetadataService:
account = GoogleDriveAccount.from_connector(connector) account = GoogleDriveAccount.from_connector(connector)
file = GoogleDriveFile.from_document(document) file = GoogleDriveFile.from_document(document)
acc_dict = account.to_dict()
auth_expired = await self._check_account_health(connector.id)
acc_dict["auth_expired"] = auth_expired
if auth_expired:
await self._persist_auth_expired(connector.id)
return { return {
"account": account.to_dict(), "account": acc_dict,
"file": file.to_dict(), "file": file.to_dict(),
} }
@ -139,11 +170,150 @@ class GoogleDriveToolMetadataService:
and_( and_(
SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type.in_(
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR, [
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
),
) )
) )
.order_by(SearchSourceConnector.last_indexed_at.desc()) .order_by(SearchSourceConnector.last_indexed_at.desc())
) )
connectors = result.scalars().all() connectors = result.scalars().all()
return [GoogleDriveAccount.from_connector(c) for c in connectors] return [GoogleDriveAccount.from_connector(c) for c in connectors]
async def _check_account_health(self, connector_id: int) -> bool:
"""Check if a Google Drive connector's credentials are still valid.
Uses a lightweight ``files.list(pageSize=1)`` call to verify access.
Returns True if the credentials are expired/invalid, False if healthy.
"""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if not connector:
return True
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
credentials=pre_built_creds,
)
await client.list_files(
query="trashed = false", page_size=1, fields="files(id)"
)
return False
except Exception as e:
logger.warning(
"Google Drive connector %s health check failed: %s",
connector_id,
e,
)
return True
async def _persist_auth_expired(self, connector_id: int) -> None:
"""Persist ``auth_expired: True`` to the connector config if not already set."""
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
db_connector = result.scalar_one_or_none()
if db_connector and not db_connector.config.get("auth_expired"):
db_connector.config = {**db_connector.config, "auth_expired": True}
flag_modified(db_connector, "config")
await self._db_session.commit()
await self._db_session.refresh(db_connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector_id,
exc_info=True,
)
async def _get_parent_folders_by_account(
self, accounts_with_status: list[dict]
) -> dict[int, list[dict]]:
"""Fetch root-level folders for each healthy account.
Skips accounts where ``auth_expired`` is True so we don't waste an API
call that will fail anyway.
"""
parent_folders: dict[int, list[dict]] = {}
for acc in accounts_with_status:
connector_id = acc["id"]
if acc.get("auth_expired"):
parent_folders[connector_id] = []
continue
try:
result = await self._db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if not connector:
parent_folders[connector_id] = []
continue
pre_built_creds = None
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=self._db_session,
connector_id=connector_id,
credentials=pre_built_creds,
)
folders, _, error = await client.list_files(
query="mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
fields="files(id, name)",
page_size=50,
)
if error:
logger.warning(
"Failed to list folders for connector %s: %s",
connector_id,
error,
)
parent_folders[connector_id] = []
else:
parent_folders[connector_id] = [
{"folder_id": f["id"], "name": f["name"]}
for f in folders
if f.get("id") and f.get("name")
]
except Exception:
logger.warning(
"Error fetching folders for connector %s",
connector_id,
exc_info=True,
)
parent_folders[connector_id] = []
return parent_folders

View file

@ -0,0 +1,13 @@
from app.services.jira.kb_sync_service import JiraKBSyncService
from app.services.jira.tool_metadata_service import (
JiraIssue,
JiraToolMetadataService,
JiraWorkspace,
)
__all__ = [
"JiraIssue",
"JiraKBSyncService",
"JiraToolMetadataService",
"JiraWorkspace",
]

View file

@ -0,0 +1,254 @@
import asyncio
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
class JiraKBSyncService:
"""Syncs Jira issue documents to the knowledge base after HITL actions."""
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def sync_after_create(
self,
issue_id: str,
issue_identifier: str,
issue_title: str,
description: str | None,
state: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
logger.info(
"Document for Jira issue %s already exists (doc_id=%s), skipping",
issue_identifier,
existing.id,
)
return {"status": "success"}
indexable_content = (description or "").strip()
if not indexable_content:
indexable_content = f"Jira Issue {issue_identifier}: {issue_title}"
issue_content = (
f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
)
content_hash = generate_content_hash(issue_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"issue_id": issue_identifier,
"issue_title": issue_title,
"document_type": "Jira Issue",
"connector_type": "Jira",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
issue_content, user_llm, doc_metadata_for_summary
)
else:
summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(issue_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document = Document(
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.JIRA_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state or "Unknown",
"indexed_at": now_str,
"connector_id": connector_id,
},
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, issue=%s",
document.id,
issue_identifier,
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for issue %s: %s",
issue_identifier,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def sync_after_update(
self,
document_id: int,
issue_id: str,
user_id: str,
search_space_id: int,
) -> dict:
from app.tasks.connector_indexers.base import (
get_current_timestamp,
safe_set_chunks,
)
try:
document = await self.db_session.get(Document, document_id)
if not document:
return {"status": "not_indexed"}
connector_id = document.connector_id
if not connector_id:
return {"status": "error", "message": "Document has no connector_id"}
jira_history = JiraHistoryConnector(
session=self.db_session, connector_id=connector_id
)
jira_client = await jira_history._get_jira_client()
issue_raw = await asyncio.to_thread(jira_client.get_issue, issue_id)
formatted = jira_client.format_issue(issue_raw)
issue_content = jira_client.format_issue_to_markdown(formatted)
if not issue_content:
return {"status": "error", "message": "Issue produced empty content"}
issue_identifier = formatted.get("key", "")
issue_title = formatted.get("title", "")
state = formatted.get("status", "Unknown")
comment_count = len(formatted.get("comments", []))
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
)
if user_llm:
doc_meta = {
"issue_key": issue_identifier,
"issue_title": issue_title,
"status": state,
"document_type": "Jira Issue",
"connector_type": "Jira",
}
summary_content, summary_embedding = await generate_document_summary(
issue_content, user_llm, doc_meta
)
else:
summary_content = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(issue_content)
document.title = f"{issue_identifier}: {issue_title}"
document.content = summary_content
document.content_hash = generate_content_hash(
issue_content, search_space_id
)
document.embedding = summary_embedding
from sqlalchemy.orm.attributes import flag_modified
document.document_metadata = {
**(document.document_metadata or {}),
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"comment_count": comment_count,
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
flag_modified(document, "document_metadata")
await safe_set_chunks(self.db_session, document, chunks)
document.updated_at = get_current_timestamp()
await self.db_session.commit()
logger.info(
"KB sync successful for document %s (%s: %s)",
document_id,
issue_identifier,
issue_title,
)
return {"status": "success"}
except Exception as e:
logger.error(
"KB sync failed for document %s: %s", document_id, e, exc_info=True
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}

View file

@ -0,0 +1,332 @@
import asyncio
import logging
from dataclasses import dataclass
from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.jira_history import JiraHistoryConnector
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
logger = logging.getLogger(__name__)
@dataclass
class JiraWorkspace:
"""Represents a Jira connector as a workspace for tool context."""
id: int
name: str
base_url: str
@classmethod
def from_connector(cls, connector: SearchSourceConnector) -> "JiraWorkspace":
return cls(
id=connector.id,
name=connector.name,
base_url=connector.config.get("base_url", ""),
)
def to_dict(self) -> dict:
return {
"id": self.id,
"name": self.name,
"base_url": self.base_url,
}
@dataclass
class JiraIssue:
"""Represents an indexed Jira issue resolved from the knowledge base."""
issue_id: str
issue_identifier: str
issue_title: str
state: str
connector_id: int
document_id: int
indexed_at: str | None
@classmethod
def from_document(cls, document: Document) -> "JiraIssue":
meta = document.document_metadata or {}
return cls(
issue_id=meta.get("issue_id", ""),
issue_identifier=meta.get("issue_identifier", ""),
issue_title=meta.get("issue_title", document.title),
state=meta.get("state", ""),
connector_id=document.connector_id,
document_id=document.id,
indexed_at=meta.get("indexed_at"),
)
def to_dict(self) -> dict:
return {
"issue_id": self.issue_id,
"issue_identifier": self.issue_identifier,
"issue_title": self.issue_title,
"state": self.state,
"connector_id": self.connector_id,
"document_id": self.document_id,
"indexed_at": self.indexed_at,
}
class JiraToolMetadataService:
"""Builds interrupt context for Jira HITL tools."""
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
"""Check if the Jira connector auth is still valid.
Returns True if auth is expired/invalid, False if healthy.
"""
try:
jira_history = JiraHistoryConnector(
session=self._db_session, connector_id=connector.id
)
jira_client = await jira_history._get_jira_client()
await asyncio.to_thread(jira_client.get_myself)
return False
except Exception as e:
logger.warning("Jira connector %s health check failed: %s", connector.id, e)
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await self._db_session.commit()
await self._db_session.refresh(connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return True
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
"""Return context needed to create a new Jira issue.
Fetches all connected Jira accounts, and for the first healthy one
fetches projects, issue types, and priorities.
"""
connectors = await self._get_all_jira_connectors(search_space_id, user_id)
if not connectors:
return {"error": "No Jira account connected"}
accounts = []
projects = []
issue_types = []
priorities = []
fetched_context = False
for connector in connectors:
auth_expired = await self._check_account_health(connector)
workspace = JiraWorkspace.from_connector(connector)
account_info = {
**workspace.to_dict(),
"auth_expired": auth_expired,
}
accounts.append(account_info)
if not auth_expired and not fetched_context:
try:
jira_history = JiraHistoryConnector(
session=self._db_session, connector_id=connector.id
)
jira_client = await jira_history._get_jira_client()
raw_projects = await asyncio.to_thread(jira_client.get_projects)
projects = [
{"id": p.get("id"), "key": p.get("key"), "name": p.get("name")}
for p in raw_projects
]
raw_types = await asyncio.to_thread(jira_client.get_issue_types)
seen_type_names: set[str] = set()
issue_types = []
for t in raw_types:
if t.get("subtask", False):
continue
name = t.get("name")
if name not in seen_type_names:
seen_type_names.add(name)
issue_types.append({"id": t.get("id"), "name": name})
raw_priorities = await asyncio.to_thread(jira_client.get_priorities)
priorities = [
{"id": p.get("id"), "name": p.get("name")}
for p in raw_priorities
]
fetched_context = True
except Exception as e:
logger.warning(
"Failed to fetch Jira context for connector %s: %s",
connector.id,
e,
)
return {
"accounts": accounts,
"projects": projects,
"issue_types": issue_types,
"priorities": priorities,
}
async def get_update_context(
self, search_space_id: int, user_id: str, issue_ref: str
) -> dict:
"""Return context needed to update an indexed Jira issue.
Resolves the issue from the KB, then fetches current details from the Jira API.
"""
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
if not document:
return {
"error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
"Please make sure the issue is indexed in your knowledge base."
}
connector = await self._get_connector_for_document(document, user_id)
if not connector:
return {"error": "Connector not found or access denied"}
auth_expired = await self._check_account_health(connector)
if auth_expired:
return {
"error": "Jira authentication has expired. Please re-authenticate.",
"auth_expired": True,
"connector_id": connector.id,
}
workspace = JiraWorkspace.from_connector(connector)
issue = JiraIssue.from_document(document)
try:
jira_history = JiraHistoryConnector(
session=self._db_session, connector_id=connector.id
)
jira_client = await jira_history._get_jira_client()
issue_data = await asyncio.to_thread(jira_client.get_issue, issue.issue_id)
formatted = jira_client.format_issue(issue_data)
except Exception as e:
error_str = str(e).lower()
if (
"401" in error_str
or "403" in error_str
or "authentication" in error_str
):
return {
"error": f"Failed to fetch Jira issue: {e!s}",
"auth_expired": True,
"connector_id": connector.id,
}
return {"error": f"Failed to fetch Jira issue: {e!s}"}
return {
"account": {**workspace.to_dict(), "auth_expired": False},
"issue": {
"issue_id": formatted.get("key", issue.issue_id),
"issue_identifier": formatted.get("key", issue.issue_identifier),
"issue_title": formatted.get("title", issue.issue_title),
"state": formatted.get("status", "Unknown"),
"priority": formatted.get("priority", "Unknown"),
"issue_type": formatted.get("issue_type", "Unknown"),
"assignee": formatted.get("assignee"),
"description": formatted.get("description"),
"project": formatted.get("project", ""),
"document_id": issue.document_id,
"indexed_at": issue.indexed_at,
},
}
async def get_deletion_context(
self, search_space_id: int, user_id: str, issue_ref: str
) -> dict:
"""Return context needed to delete a Jira issue (KB metadata only, no API call)."""
document = await self._resolve_issue(search_space_id, user_id, issue_ref)
if not document:
return {
"error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
"Please make sure the issue is indexed in your knowledge base."
}
connector = await self._get_connector_for_document(document, user_id)
if not connector:
return {"error": "Connector not found or access denied"}
auth_expired = connector.config.get("auth_expired", False)
workspace = JiraWorkspace.from_connector(connector)
issue = JiraIssue.from_document(document)
return {
"account": {**workspace.to_dict(), "auth_expired": auth_expired},
"issue": issue.to_dict(),
}
async def _resolve_issue(
self, search_space_id: int, user_id: str, issue_ref: str
) -> Document | None:
"""Resolve an issue from KB: issue_identifier -> issue_title -> document.title."""
ref_lower = issue_ref.lower()
result = await self._db_session.execute(
select(Document)
.join(
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.JIRA_CONNECTOR,
SearchSourceConnector.user_id == user_id,
or_(
func.lower(
Document.document_metadata.op("->>")("issue_identifier")
)
== ref_lower,
func.lower(Document.document_metadata.op("->>")("issue_title"))
== ref_lower,
func.lower(Document.title) == ref_lower,
),
)
)
.order_by(Document.updated_at.desc().nullslast())
.limit(1)
)
return result.scalars().first()
async def _get_all_jira_connectors(
self, search_space_id: int, user_id: str
) -> list[SearchSourceConnector]:
result = await self._db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.JIRA_CONNECTOR,
)
)
)
return result.scalars().all()
async def _get_connector_for_document(
self, document: Document, user_id: str
) -> SearchSourceConnector | None:
if not document.connector_id:
return None
result = await self._db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.user_id == user_id,
)
)
)
return result.scalars().first()

View file

@ -4,29 +4,174 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector from app.connectors.linear_connector import LinearConnector
from app.db import Document from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import ( from app.utils.document_converters import (
create_document_chunks, create_document_chunks,
embed_text, embed_text,
generate_content_hash, generate_content_hash,
generate_document_summary, generate_document_summary,
generate_unique_identifier_hash,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LinearKBSyncService: class LinearKBSyncService:
"""Re-indexes a single Linear issue document after a successful update. """Syncs Linear issue documents to the knowledge base after HITL actions.
Mirrors the indexer's Phase-2 logic exactly: fetch fresh issue content, Provides sync_after_create (new issue) and sync_after_update (existing issue).
run generate_document_summary, create_document_chunks, then update the Both mirror the indexer's Phase-2 logic: generate summary, create chunks,
document row in the knowledge base. then persist the document row.
""" """
def __init__(self, db_session: AsyncSession): def __init__(self, db_session: AsyncSession):
self.db_session = db_session self.db_session = db_session
async def sync_after_create(
self,
issue_id: str,
issue_identifier: str,
issue_title: str,
issue_url: str | None,
description: str | None,
connector_id: int,
search_space_id: int,
user_id: str,
) -> dict:
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_current_timestamp,
safe_set_chunks,
)
try:
unique_hash = generate_unique_identifier_hash(
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
)
existing = await check_document_by_unique_identifier(
self.db_session, unique_hash
)
if existing:
logger.info(
"Document for Linear issue %s already exists (doc_id=%s), skipping",
issue_identifier,
existing.id,
)
return {"status": "success"}
indexable_content = (description or "").strip()
if not indexable_content:
indexable_content = f"Linear Issue {issue_identifier}: {issue_title}"
issue_content = (
f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
)
content_hash = generate_content_hash(issue_content, search_space_id)
with self.db_session.no_autoflush:
dup = await check_duplicate_document_by_hash(
self.db_session, content_hash
)
if dup:
logger.info(
"Content-hash collision for Linear issue %s — identical content "
"exists in doc %s. Using unique_identifier_hash as content_hash.",
issue_identifier,
dup.id,
)
content_hash = unique_hash
user_llm = await get_user_long_context_llm(
self.db_session,
user_id,
search_space_id,
disable_streaming=True,
)
doc_metadata_for_summary = {
"issue_id": issue_identifier,
"issue_title": issue_title,
"document_type": "Linear Issue",
"connector_type": "Linear",
}
if user_llm:
summary_content, summary_embedding = await generate_document_summary(
issue_content, user_llm, doc_metadata_for_summary
)
else:
logger.warning("No LLM configured — using fallback summary")
summary_content = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(issue_content)
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
document = Document(
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.LINEAR_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"issue_url": issue_url,
"source_connector": "linear",
"indexed_at": now_str,
"connector_id": connector_id,
},
content=summary_content,
content_hash=content_hash,
unique_identifier_hash=unique_hash,
embedding=summary_embedding,
search_space_id=search_space_id,
connector_id=connector_id,
updated_at=get_current_timestamp(),
created_by_id=user_id,
)
self.db_session.add(document)
await self.db_session.flush()
await safe_set_chunks(self.db_session, document, chunks)
await self.db_session.commit()
logger.info(
"KB sync after create succeeded: doc_id=%s, issue=%s, chunks=%d",
document.id,
issue_identifier,
len(chunks),
)
return {"status": "success"}
except Exception as e:
error_str = str(e).lower()
if (
"duplicate key value violates unique constraint" in error_str
or "uniqueviolationerror" in error_str
):
logger.warning(
"Duplicate constraint hit during KB sync for issue %s. "
"Rolling back — periodic indexer will handle it. Error: %s",
issue_identifier,
e,
)
await self.db_session.rollback()
return {"status": "error", "message": "Duplicate document detected"}
logger.error(
"KB sync after create failed for issue %s: %s",
issue_identifier,
e,
exc_info=True,
)
await self.db_session.rollback()
return {"status": "error", "message": str(e)}
async def sync_after_update( async def sync_after_update(
self, self,
document_id: int, document_id: int,

View file

@ -1,8 +1,10 @@
import logging
from dataclasses import dataclass from dataclasses import dataclass
from sqlalchemy import and_, func, or_ from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.linear_connector import LinearConnector from app.connectors.linear_connector import LinearConnector
from app.db import ( from app.db import (
@ -12,6 +14,8 @@ from app.db import (
SearchSourceConnectorType, SearchSourceConnectorType,
) )
logger = logging.getLogger(__name__)
@dataclass @dataclass
class LinearWorkspace: class LinearWorkspace:
@ -109,7 +113,34 @@ class LinearToolMetadataService:
priorities = await self._fetch_priority_values(linear_client) priorities = await self._fetch_priority_values(linear_client)
teams = await self._fetch_teams_context(linear_client) teams = await self._fetch_teams_context(linear_client)
except Exception as e: except Exception as e:
return {"error": f"Failed to fetch Linear context: {e!s}"} logger.warning(
"Linear connector %s (%s) auth failed, flagging as expired: %s",
connector.id,
workspace.name,
e,
)
try:
connector.config = {**connector.config, "auth_expired": True}
flag_modified(connector, "config")
await self._db_session.commit()
await self._db_session.refresh(connector)
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
workspaces.append(
{
"id": workspace.id,
"name": workspace.name,
"organization_name": workspace.organization_name,
"teams": [],
"priorities": [],
"auth_expired": True,
}
)
continue
workspaces.append( workspaces.append(
{ {
"id": workspace.id, "id": workspace.id,
@ -117,6 +148,7 @@ class LinearToolMetadataService:
"organization_name": workspace.organization_name, "organization_name": workspace.organization_name,
"teams": teams, "teams": teams,
"priorities": priorities, "priorities": priorities,
"auth_expired": False,
} }
) )
@ -137,8 +169,8 @@ class LinearToolMetadataService:
document = await self._resolve_issue(search_space_id, user_id, issue_ref) document = await self._resolve_issue(search_space_id, user_id, issue_ref)
if not document: if not document:
return { return {
"error": f"Issue '{issue_ref}' not found in your indexed Linear issues. " "error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
"This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, " "This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
"or (3) the title or identifier is different." "or (3) the title or identifier is different."
} }
@ -157,6 +189,17 @@ class LinearToolMetadataService:
priorities = await self._fetch_priority_values(linear_client) priorities = await self._fetch_priority_values(linear_client)
issue_api = await self._fetch_issue_context(linear_client, issue.id) issue_api = await self._fetch_issue_context(linear_client, issue.id)
except Exception as e: except Exception as e:
error_str = str(e).lower()
if (
"401" in error_str
or "authentication" in error_str
or "re-authenticate" in error_str
):
return {
"error": f"Failed to fetch Linear issue context: {e!s}",
"auth_expired": True,
"connector_id": connector.id,
}
return {"error": f"Failed to fetch Linear issue context: {e!s}"} return {"error": f"Failed to fetch Linear issue context: {e!s}"}
if not issue_api: if not issue_api:
@ -210,8 +253,8 @@ class LinearToolMetadataService:
document = await self._resolve_issue(search_space_id, user_id, issue_ref) document = await self._resolve_issue(search_space_id, user_id, issue_ref)
if not document: if not document:
return { return {
"error": f"Issue '{issue_ref}' not found in your indexed Linear issues. " "error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
"This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, " "This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
"or (3) the title or identifier is different." "or (3) the title or identifier is different."
} }
@ -319,6 +362,7 @@ class LinearToolMetadataService:
), ),
) )
) )
.order_by(Document.updated_at.desc().nullslast())
.limit(1) .limit(1)
) )
return result.scalars().first() return result.scalars().first()

Some files were not shown because too many files have changed in this diff Show more