mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00

Merge pull request #966 from MODSetter/dev
feat: HITL Workflows and Fixing Real-Time Sync

commit 8227d1852f
320 changed files with 33857 additions and 19630 deletions

.github/workflows/desktop-release.yml (vendored, 2 changes)
@@ -57,7 +57,7 @@ jobs:
       working-directory: surfsense_web
       env:
         NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }}
-        NEXT_PUBLIC_ELECTRIC_URL: ${{ vars.NEXT_PUBLIC_ELECTRIC_URL }}
+        NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
         NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
         NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}
.github/workflows/docker-build.yml (vendored, 3 changes)
@@ -164,8 +164,7 @@ jobs:
           ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
           ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
           ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
-          ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__' || '' }}
-          ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__' || '' }}
+          ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }}
           ${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}

       - name: Export digest
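Note on the build args above: the `__NEXT_PUBLIC_*__` values are placeholders baked into the image at build time, to be swapped for real values when the container starts. The actual entrypoint that performs the swap is not part of this diff; purely as a hypothetical sketch of the pattern (the paths and variable list here are assumptions):

# Hypothetical sketch of runtime placeholder substitution (the real
# entrypoint script is not shown in this diff). Walks the built assets
# and swaps __NEXT_PUBLIC_X__ markers for the container's env values.
import os
from pathlib import Path

PLACEHOLDERS = [
    "NEXT_PUBLIC_FASTAPI_BACKEND_URL",
    "NEXT_PUBLIC_ZERO_CACHE_URL",
    "NEXT_PUBLIC_DEPLOYMENT_MODE",
]

def substitute(build_dir: str = ".next") -> None:
    for path in Path(build_dir).rglob("*.js"):
        text = path.read_text(errors="ignore")
        for name in PLACEHOLDERS:
            text = text.replace(f"__{name}__", os.environ.get(name, ""))
        path.write_text(text)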
@@ -35,7 +35,7 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2

 # BACKEND_PORT=8929
 # FRONTEND_PORT=3929
-# ELECTRIC_PORT=5929
+# ZERO_CACHE_PORT=5929
 # SEARXNG_PORT=8888
 # FLOWER_PORT=5555

@@ -58,7 +58,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 # NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
 # NEXT_PUBLIC_ETL_SERVICE=DOCLING
 # NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
-# NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure

 # ------------------------------------------------------------------------------
 # Custom Domain / Reverse Proxy
@@ -71,8 +70,35 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 # NEXT_FRONTEND_URL=https://app.yourdomain.com
 # BACKEND_URL=https://api.yourdomain.com
 # NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
-# NEXT_PUBLIC_ELECTRIC_URL=https://electric.yourdomain.com
+# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com

+# ------------------------------------------------------------------------------
+# Zero-cache (real-time sync)
+# ------------------------------------------------------------------------------
+# Defaults work out of the box for Docker deployments.
+# Change ZERO_ADMIN_PASSWORD for security in production.
+
+# ZERO_ADMIN_PASSWORD=surfsense-zero-admin
+# Full override for the Zero → Postgres connection URLs.
+# Leave commented out to use the Docker-managed `db` container (default).
+# ZERO_UPSTREAM_DB=postgresql://surfsense:surfsense@db:5432/surfsense
+# ZERO_CVR_DB=postgresql://surfsense:surfsense@db:5432/surfsense
+# ZERO_CHANGE_DB=postgresql://surfsense:surfsense@db:5432/surfsense
+
+# ZERO_QUERY_URL: where zero-cache forwards query requests for resolution.
+# ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though
+# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
+# skip its own JWT verification and let the app endpoints handle auth instead.
+# The mutate endpoint is a no-op that returns an empty response.
+# Default: Docker service networking (http://frontend:3000/api/zero/...).
+# Override when running the frontend outside Docker:
+# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
+# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
+# Override for custom domain:
+# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query
+# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate
+# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
+# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate

 # ------------------------------------------------------------------------------
 # Database (defaults work out of the box, change for security)
@@ -101,19 +127,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 # Supports TLS: rediss://:password@host:6380/0
 # REDIS_URL=redis://redis:6379/0

-# ------------------------------------------------------------------------------
-# Electric SQL (real-time sync credentials)
-# ------------------------------------------------------------------------------
-# These must match on the db, backend, and electric services.
-# Change for security; defaults work out of the box.
-
-# ELECTRIC_DB_USER=electric
-# ELECTRIC_DB_PASSWORD=electric_password
-# Full override for the Electric → Postgres connection URL.
-# Leave commented out to use the Docker-managed `db` container (default).
-# Uncomment and set `db` to `host.docker.internal` when pointing Electric at a local Postgres instance (e.g. Postgres.app on macOS):
-# ELECTRIC_DATABASE_URL=postgresql://electric:electric_password@db:5432/surfsense?sslmode=disable
-
 # ------------------------------------------------------------------------------
 # TTS & STT (Text-to-Speech / Speech-to-Text)
 # ------------------------------------------------------------------------------
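The comment block above fully describes the contract of the two app endpoints zero-cache calls. The real implementations are frontend routes under /api/zero/ and are not shown in this diff; as an illustration only, a minimal Python sketch of a no-op mutate endpoint that returns an empty response (assuming FastAPI):

# Illustrative sketch only (assumes FastAPI; the actual SurfSense endpoints
# are frontend routes under /api/zero/). zero-cache POSTs mutations here;
# since SurfSense does not use Zero mutators, an empty response suffices.
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/api/zero/mutate")
async def zero_mutate(request: Request) -> list:
    # No-op: acknowledge the push without applying any mutators.
    return []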
@@ -18,13 +18,10 @@ services:
     volumes:
       - postgres_data:/var/lib/postgresql/data
       - ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
-      - ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
     environment:
       - POSTGRES_USER=${DB_USER:-postgres}
       - POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
       - POSTGRES_DB=${DB_NAME:-surfsense}
-      - ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
-      - ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
     command: postgres -c config_file=/etc/postgresql/postgresql.conf
     healthcheck:
       test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
@@ -91,8 +88,6 @@ services:
       - UNSTRUCTURED_HAS_PATCHED_LOOP=1
       - LANGCHAIN_TRACING_V2=false
       - LANGSMITH_TRACING=false
-      - ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
-      - ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
       - AUTH_TYPE=${AUTH_TYPE:-LOCAL}
       - NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
       - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
@@ -130,8 +125,6 @@ services:
       - REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
       - CELERY_TASK_DEFAULT_QUEUE=surfsense
      - PYTHONPATH=/app
-      - ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
-      - ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
       - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
       - SERVICE_ROLE=worker
     depends_on:
@@ -176,20 +169,28 @@ services:
   #   - redis
   #   - celery_worker

-  electric:
-    image: electricsql/electric:1.4.10
+  zero-cache:
+    image: rocicorp/zero:0.26.2
     ports:
-      - "${ELECTRIC_PORT:-5133}:3000"
+      - "${ZERO_CACHE_PORT:-4848}:4848"
+    extra_hosts:
+      - "host.docker.internal:host-gateway"
     depends_on:
       db:
         condition: service_healthy
     environment:
-      - DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
-      - ELECTRIC_INSECURE=true
-      - ELECTRIC_WRITE_TO_PG_MODE=direct
+      - ZERO_UPSTREAM_DB=${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      - ZERO_CVR_DB=${ZERO_CVR_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      - ZERO_CHANGE_DB=${ZERO_CHANGE_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      - ZERO_REPLICA_FILE=/data/zero.db
+      - ZERO_ADMIN_PASSWORD=${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
+      - ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
+      - ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
+    volumes:
+      - zero_cache_data:/data
     restart: unless-stopped
     healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
+      test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
       interval: 10s
       timeout: 5s
       retries: 5
@@ -201,8 +202,7 @@ services:
       NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
       NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
       NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
-      NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
-      NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
+      NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}}
       NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
     ports:
       - "${FRONTEND_PORT:-3000}:3000"
@@ -211,7 +211,7 @@ services:
     depends_on:
       backend:
         condition: service_healthy
-      electric:
+      zero-cache:
         condition: service_healthy

 volumes:
@@ -223,3 +223,5 @@ volumes:
     name: surfsense-dev-redis
   shared_temp:
     name: surfsense-dev-shared-temp
+  zero_cache_data:
+    name: surfsense-dev-zero-cache
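Both compose files probe the new service with curl against Zero's /keepalive endpoint. A minimal host-side sketch of the same check, assuming the dev port mapping (ZERO_CACHE_PORT, default 4848):

# Minimal host-side equivalent of the compose healthcheck
# (assumes the dev default port mapping, ZERO_CACHE_PORT:-4848).
import time
import urllib.request

def wait_for_zero_cache(url: str = "http://localhost:4848/keepalive",
                        retries: int = 5, interval: float = 10.0) -> bool:
    for _ in range(retries):
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass
        time.sleep(interval)
    return False

if __name__ == "__main__":
    print("zero-cache healthy:", wait_for_zero_cache())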
@@ -15,13 +15,10 @@ services:
     volumes:
       - postgres_data:/var/lib/postgresql/data
       - ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
-      - ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
     environment:
       POSTGRES_USER: ${DB_USER:-surfsense}
       POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense}
       POSTGRES_DB: ${DB_NAME:-surfsense}
-      ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
-      ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
     command: postgres -c config_file=/etc/postgresql/postgresql.conf
     restart: unless-stopped
     healthcheck:
@@ -72,8 +69,6 @@ services:
       PYTHONPATH: /app
       UVICORN_LOOP: asyncio
       UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
-      ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
-      ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
       NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}}
       SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
       # Daytona Sandbox – uncomment and set credentials to enable cloud code execution
@@ -112,8 +107,6 @@ services:
       REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
       CELERY_TASK_DEFAULT_QUEUE: surfsense
       PYTHONPATH: /app
-      ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
-      ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
       SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
       SERVICE_ROLE: worker
     depends_on:
@@ -165,20 +158,28 @@ services:
   #   - celery_worker
   # restart: unless-stopped

-  electric:
-    image: electricsql/electric:1.4.10
+  zero-cache:
+    image: rocicorp/zero:0.26.2
     ports:
-      - "${ELECTRIC_PORT:-5929}:3000"
+      - "${ZERO_CACHE_PORT:-5929}:4848"
+    extra_hosts:
+      - "host.docker.internal:host-gateway"
     environment:
-      DATABASE_URL: ${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
-      ELECTRIC_INSECURE: "true"
-      ELECTRIC_WRITE_TO_PG_MODE: direct
+      ZERO_UPSTREAM_DB: ${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      ZERO_CVR_DB: ${ZERO_CVR_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      ZERO_CHANGE_DB: ${ZERO_CHANGE_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
+      ZERO_REPLICA_FILE: /data/zero.db
+      ZERO_ADMIN_PASSWORD: ${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
+      ZERO_QUERY_URL: ${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
+      ZERO_MUTATE_URL: ${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
+    volumes:
+      - zero_cache_data:/data
     restart: unless-stopped
     depends_on:
       db:
         condition: service_healthy
     healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
+      test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
       interval: 10s
       timeout: 5s
       retries: 5
@@ -189,17 +190,16 @@ services:
       - "${FRONTEND_PORT:-3929}:3000"
     environment:
       NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}}
-      NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:${ELECTRIC_PORT:-5929}}
+      NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}}
       NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
       NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
       NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
-      NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
     labels:
       - "com.centurylinklabs.watchtower.enable=true"
     depends_on:
       backend:
         condition: service_healthy
-      electric:
+      zero-cache:
         condition: service_healthy
     restart: unless-stopped

@@ -210,3 +210,5 @@ volumes:
     name: surfsense-redis
   shared_temp:
     name: surfsense-shared-temp
+  zero_cache_data:
+    name: surfsense-zero-cache
@@ -1,11 +1,11 @@
-# PostgreSQL configuration for Electric SQL
+# PostgreSQL configuration for SurfSense
 # This file is mounted into the PostgreSQL container

 listen_addresses = '*'
 max_connections = 200
 shared_buffers = 256MB

-# Enable logical replication (required for Electric SQL)
+# Enable logical replication (required for Zero-cache real-time sync)
 wal_level = logical
 max_replication_slots = 10
 max_wal_senders = 10
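wal_level = logical is what allows the sync service to stream row changes out of Postgres. A quick sketch for verifying the settings took effect, assuming psycopg 3 and the compose defaults:

# Sketch: verify the logical-replication settings from postgresql.conf
# (assumes psycopg 3 and the default credentials from the compose file).
import psycopg

DSN = "postgresql://surfsense:surfsense@localhost:5432/surfsense"

with psycopg.connect(DSN) as conn:
    for setting in ("wal_level", "max_replication_slots", "max_wal_senders"):
        value = conn.execute(f"SHOW {setting}").fetchone()[0]
        print(f"{setting} = {value}")  # expect: logical, 10, 10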
@@ -1,38 +0,0 @@
-#!/bin/sh
-# Creates the Electric SQL replication user on first DB initialization.
-# Idempotent — safe to run alongside Alembic migration 66.
-
-set -e
-
-ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
-ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
-
-echo "Creating Electric SQL replication user: $ELECTRIC_DB_USER"
-
-psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
-  DO \$\$
-  BEGIN
-    IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '$ELECTRIC_DB_USER') THEN
-      CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
-    END IF;
-  END
-  \$\$;
-
-  GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
-  GRANT CREATE ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
-  GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
-  GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
-  GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
-  ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
-  ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
-
-  DO \$\$
-  BEGIN
-    IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
-      CREATE PUBLICATION electric_publication_default;
-    END IF;
-  END
-  \$\$;
-EOSQL
-
-echo "Electric SQL user '$ELECTRIC_DB_USER' and publication created successfully"
@@ -109,7 +109,6 @@ $Files = @(
     @{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" }
     @{ Src = "docker/.env.example"; Dest = ".env.example" }
     @{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" }
-    @{ Src = "docker/scripts/init-electric-user.sh"; Dest = "scripts/init-electric-user.sh" }
     @{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" }
     @{ Src = "docker/searxng/settings.yml"; Dest = "searxng/settings.yml" }
     @{ Src = "docker/searxng/limiter.toml"; Dest = "searxng/limiter.toml" }
@@ -108,7 +108,6 @@ FILES=(
   "docker/docker-compose.yml:docker-compose.yml"
   "docker/.env.example:.env.example"
   "docker/postgresql.conf:postgresql.conf"
-  "docker/scripts/init-electric-user.sh:scripts/init-electric-user.sh"
   "docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
   "docker/searxng/settings.yml:searxng/settings.yml"
   "docker/searxng/limiter.toml:searxng/limiter.toml"
@@ -122,7 +121,6 @@ for entry in "${FILES[@]}"; do
     || error "Failed to download ${dest}. Check your internet connection and try again."
 done

-chmod +x "${INSTALL_DIR}/scripts/init-electric-user.sh"
 chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh"
 success "All files downloaded to ${INSTALL_DIR}/"
@@ -17,10 +17,6 @@ REDIS_APP_URL=redis://localhost:6379/0
 # Only uncomment if running the backend outside Docker (e.g. uvicorn on host).
 # SEARXNG_DEFAULT_HOST=http://localhost:8888
-
-#Electric(for migrations only)
-ELECTRIC_DB_USER=electric
-ELECTRIC_DB_PASSWORD=electric_password

 # Periodic task interval
 # # Run every minute (default)
 # SCHEDULE_CHECKER_INTERVAL=1m
@@ -104,7 +100,8 @@ TEAMS_CLIENT_ID=your_teams_client_id_here
 TEAMS_CLIENT_SECRET=your_teams_client_secret_here
 TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback

-#Composio Coonnector
+# Composio Connector
+# NOTE: Disable "Mask Connected Account Secrets" in Composio dashboard (Settings → Project Settings) for Google indexing to work.
 COMPOSIO_API_KEY=your_api_key_here
 COMPOSIO_ENABLED=TRUE
 COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
@@ -25,13 +25,6 @@ database_url = os.getenv("DATABASE_URL")
 if database_url:
     config.set_main_option("sqlalchemy.url", database_url)
-
-# Electric SQL user credentials - centralized configuration for migrations
-# These are used by migrations that set up Electric SQL replication
-config.set_main_option("electric_db_user", os.getenv("ELECTRIC_DB_USER", "electric"))
-config.set_main_option(
-    "electric_db_password", os.getenv("ELECTRIC_DB_PASSWORD", "electric_password")
-)

 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
 if config.config_file_name is not None:
@@ -30,21 +30,25 @@ def upgrade() -> None:
         "ix_notifications_user_read_type_created",
         "notifications",
         ["user_id", "read", "type", "created_at"],
+        if_not_exists=True,
     )
     op.create_index(
         "ix_notifications_user_space_created",
         "notifications",
         ["user_id", "search_space_id", "created_at"],
+        if_not_exists=True,
     )
     op.create_index(
         "ix_notifications_type",
         "notifications",
         ["type"],
+        if_not_exists=True,
     )
     op.create_index(
         "ix_notifications_search_space_id",
         "notifications",
         ["search_space_id"],
+        if_not_exists=True,
     )

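With if_not_exists=True, Alembic emits CREATE INDEX IF NOT EXISTS on PostgreSQL, so re-running this upgrade against a database that already has the indexes no longer fails. A minimal sketch of the equivalent raw SQL for one index:

# Sketch: the idempotent-create pattern the migration relies on,
# written as the equivalent raw SQL for one of the indexes above.
from alembic import op

def upgrade() -> None:
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_notifications_type "
        "ON notifications (type)"
    )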
@@ -35,52 +35,60 @@ def upgrade() -> None:
         END $$;
     """)

-    op.create_table(
-        "video_presentations",
-        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
-        sa.Column("title", sa.String(length=500), nullable=False),
-        sa.Column("slides", JSONB(), nullable=True),
-        sa.Column("scene_codes", JSONB(), nullable=True),
-        sa.Column(
-            "status",
-            video_presentation_status_enum,
-            server_default="ready",
-            nullable=False,
-        ),
-        sa.Column("search_space_id", sa.Integer(), nullable=False),
-        sa.Column("thread_id", sa.Integer(), nullable=True),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            server_default=sa.text("now()"),
-            nullable=False,
-        ),
-        sa.ForeignKeyConstraint(
-            ["search_space_id"],
-            ["searchspaces.id"],
-            ondelete="CASCADE",
-        ),
-        sa.ForeignKeyConstraint(
-            ["thread_id"],
-            ["new_chat_threads.id"],
-            ondelete="SET NULL",
-        ),
-        sa.PrimaryKeyConstraint("id"),
+    conn = op.get_bind()
+    result = conn.execute(
+        sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = 'video_presentations'")
+    )
+    if not result.fetchone():
+        op.create_table(
+            "video_presentations",
+            sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
+            sa.Column("title", sa.String(length=500), nullable=False),
+            sa.Column("slides", JSONB(), nullable=True),
+            sa.Column("scene_codes", JSONB(), nullable=True),
+            sa.Column(
+                "status",
+                video_presentation_status_enum,
+                server_default="ready",
+                nullable=False,
+            ),
+            sa.Column("search_space_id", sa.Integer(), nullable=False),
+            sa.Column("thread_id", sa.Integer(), nullable=True),
+            sa.Column(
+                "created_at",
+                sa.TIMESTAMP(timezone=True),
+                server_default=sa.text("now()"),
+                nullable=False,
+            ),
+            sa.ForeignKeyConstraint(
+                ["search_space_id"],
+                ["searchspaces.id"],
+                ondelete="CASCADE",
+            ),
+            sa.ForeignKeyConstraint(
+                ["thread_id"],
+                ["new_chat_threads.id"],
+                ondelete="SET NULL",
+            ),
+            sa.PrimaryKeyConstraint("id"),
+        )
     op.create_index(
         "ix_video_presentations_status",
         "video_presentations",
         ["status"],
+        if_not_exists=True,
     )
     op.create_index(
         "ix_video_presentations_thread_id",
         "video_presentations",
         ["thread_id"],
+        if_not_exists=True,
     )
     op.create_index(
         "ix_video_presentations_created_at",
         "video_presentations",
         ["created_at"],
+        if_not_exists=True,
     )

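The guard above is a standard existence check that makes create_table idempotent. Distilled into a reusable helper (a sketch; the helper is ours, not part of the repo):

# Hypothetical helper distilling the guard used above.
import sqlalchemy as sa
from alembic import op

def table_exists(name: str) -> bool:
    """Return True if `name` already exists in the current database."""
    conn = op.get_bind()
    row = conn.execute(
        sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = :t"),
        {"t": name},
    ).fetchone()
    return row is not None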
@@ -0,0 +1,104 @@
+"""Clean up Electric SQL artifacts (user, publication, replication slots)
+
+Revision ID: 108
+Revises: 107
+
+Removes leftover Electric SQL infrastructure that is no longer needed after
+the migration to Rocicorp Zero. Fully idempotent — safe on databases that
+never had Electric SQL set up (fresh installs).
+
+Cleaned up:
+- Replication slots containing 'electric' (prevents unbounded WAL growth)
+- The 'electric_publication_default' publication
+- Default privileges, grants, and the 'electric' database user
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "108"
+down_revision: str | None = "107"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    op.execute(
+        """
+        DO $$
+        DECLARE
+            slot RECORD;
+        BEGIN
+            -- 1. Drop inactive Electric replication slots (prevents WAL growth)
+            FOR slot IN
+                SELECT slot_name FROM pg_replication_slots
+                WHERE slot_name LIKE '%electric%' AND active = false
+            LOOP
+                BEGIN
+                    PERFORM pg_drop_replication_slot(slot.slot_name);
+                EXCEPTION WHEN OTHERS THEN
+                    RAISE WARNING 'Could not drop replication slot %: %', slot.slot_name, SQLERRM;
+                END;
+            END LOOP;
+
+            -- Warn about active Electric slots that cannot be safely dropped
+            FOR slot IN
+                SELECT slot_name FROM pg_replication_slots
+                WHERE slot_name LIKE '%electric%' AND active = true
+            LOOP
+                RAISE WARNING 'Active Electric replication slot "%" was not dropped — drop it manually to stop WAL growth', slot.slot_name;
+            END LOOP;
+
+            -- 2. Drop the Electric publication
+            BEGIN
+                IF EXISTS (SELECT 1 FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
+                    DROP PUBLICATION electric_publication_default;
+                END IF;
+            EXCEPTION WHEN OTHERS THEN
+                RAISE WARNING 'Could not drop publication electric_publication_default: %', SQLERRM;
+            END;
+
+            -- 3. Revoke privileges and drop the Electric user
+            IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'electric') THEN
+                BEGIN
+                    ALTER DEFAULT PRIVILEGES IN SCHEMA public
+                        REVOKE SELECT ON TABLES FROM electric;
+                    ALTER DEFAULT PRIVILEGES IN SCHEMA public
+                        REVOKE SELECT ON SEQUENCES FROM electric;
+                EXCEPTION WHEN OTHERS THEN
+                    RAISE WARNING 'Could not revoke default privileges from electric: %', SQLERRM;
+                END;
+
+                BEGIN
+                    REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM electric;
+                    REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM electric;
+                    REVOKE USAGE ON SCHEMA public FROM electric;
+                EXCEPTION WHEN OTHERS THEN
+                    RAISE WARNING 'Could not revoke schema privileges from electric: %', SQLERRM;
+                END;
+
+                BEGIN
+                    EXECUTE format(
+                        'REVOKE CONNECT ON DATABASE %I FROM electric',
+                        current_database()
+                    );
+                EXCEPTION WHEN OTHERS THEN
+                    RAISE WARNING 'Could not revoke CONNECT from electric: %', SQLERRM;
+                END;
+
+                BEGIN
+                    REASSIGN OWNED BY electric TO CURRENT_USER;
+                    DROP ROLE electric;
+                EXCEPTION WHEN OTHERS THEN
+                    RAISE WARNING 'Could not drop role electric: %', SQLERRM;
+                END;
+            END IF;
+        END
+        $$;
+        """
+    )
+
+
+def downgrade() -> None:
+    pass
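If this migration warns about an active Electric slot, the slot can be inspected before a manual pg_drop_replication_slot() call. A hedged sketch, assuming psycopg 3 and the default compose credentials:

# Sketch: list leftover Electric replication slots so an active one can be
# investigated before dropping it manually.
import psycopg

DSN = "postgresql://surfsense:surfsense@localhost:5432/surfsense"

with psycopg.connect(DSN) as conn:
    rows = conn.execute(
        "SELECT slot_name, active, active_pid FROM pg_replication_slots "
        "WHERE slot_name LIKE '%electric%'"
    ).fetchall()
    for slot_name, active, pid in rows:
        print(f"{slot_name}: active={active}, pid={pid}")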
@@ -5,7 +5,7 @@ This module provides the SurfSense deep agent with configurable tools
 for knowledge base search, podcast generation, and more.

 Directory Structure:
-- tools/: All agent tools (knowledge_base, podcast, link_preview, etc.)
+- tools/: All agent tools (knowledge_base, podcast, generate_image, etc.)
 - chat_deepagent.py: Main agent factory
 - system_prompt.py: System prompts and instructions
 - context.py: Context schema for the agent
@@ -37,9 +37,7 @@ from .tools import (
     BUILTIN_TOOLS,
     ToolDefinition,
     build_tools,
-    create_display_image_tool,
     create_generate_podcast_tool,
-    create_link_preview_tool,
     create_scrape_webpage_tool,
     create_search_knowledge_base_tool,
     format_documents_for_context,
@@ -63,9 +61,7 @@ __all__ = [
     # LLM config
     "create_chat_litellm_from_config",
     # Tool factories
-    "create_display_image_tool",
     "create_generate_podcast_tool",
-    "create_link_preview_tool",
     "create_scrape_webpage_tool",
     "create_search_knowledge_base_tool",
     # Agent factory
@@ -21,6 +21,9 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from app.agents.new_chat.context import SurfSenseContextSchema
 from app.agents.new_chat.llm_config import AgentConfig
+from app.agents.new_chat.middleware.dedup_tool_calls import (
+    DedupHITLToolCallsMiddleware,
+)
 from app.agents.new_chat.system_prompt import (
     build_configurable_system_prompt,
     build_surfsense_system_prompt,
@@ -65,10 +68,11 @@ _CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
     "BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
     "CIRCLEBACK_CONNECTOR": "CIRCLEBACK",  # Connector type differs from document type
     "OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
-    # Composio connectors
-    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
-    "COMPOSIO_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
-    "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
+    # Composio connectors (unified to native document types).
+    # Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
+    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
+    "COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
+    "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
 }

 # Document types that don't come from SearchSourceConnector but should always be searchable
@@ -146,8 +150,6 @@ async def create_surfsense_deep_agent(
     - search_knowledge_base: Search the user's personal knowledge base
     - generate_podcast: Generate audio podcasts from content
     - generate_image: Generate images from text descriptions using AI models
-    - link_preview: Fetch rich previews for URLs
-    - display_image: Display images in chat
     - scrape_webpage: Extract content from webpages
     - save_memory: Store facts/preferences about the user
     - recall_memory: Retrieve relevant user memories
@@ -203,7 +205,7 @@ async def create_surfsense_deep_agent(
     # Create agent with only specific tools
     agent = create_surfsense_deep_agent(
         llm, search_space_id, db_session, ...,
-        enabled_tools=["search_knowledge_base", "link_preview"]
+        enabled_tools=["search_knowledge_base", "scrape_webpage"]
     )

     # Create agent without podcast generation
@@ -292,6 +294,69 @@ async def create_surfsense_deep_agent(
         ]
         modified_disabled_tools.extend(linear_tools)

+    # Disable Google Drive action tools if no Google Drive connector is configured
+    has_google_drive_connector = (
+        available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
+    )
+    if not has_google_drive_connector:
+        google_drive_tools = [
+            "create_google_drive_file",
+            "delete_google_drive_file",
+        ]
+        modified_disabled_tools.extend(google_drive_tools)
+
+    # Disable Google Calendar action tools if no Google Calendar connector is configured
+    has_google_calendar_connector = (
+        available_connectors is not None
+        and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
+    )
+    if not has_google_calendar_connector:
+        calendar_tools = [
+            "create_calendar_event",
+            "update_calendar_event",
+            "delete_calendar_event",
+        ]
+        modified_disabled_tools.extend(calendar_tools)
+
+    # Disable Gmail action tools if no Gmail connector is configured
+    has_gmail_connector = (
+        available_connectors is not None
+        and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
+    )
+    if not has_gmail_connector:
+        gmail_tools = [
+            "create_gmail_draft",
+            "update_gmail_draft",
+            "send_gmail_email",
+            "trash_gmail_email",
+        ]
+        modified_disabled_tools.extend(gmail_tools)
+
+    # Disable Jira action tools if no Jira connector is configured
+    has_jira_connector = (
+        available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
+    )
+    if not has_jira_connector:
+        jira_tools = [
+            "create_jira_issue",
+            "update_jira_issue",
+            "delete_jira_issue",
+        ]
+        modified_disabled_tools.extend(jira_tools)
+
+    # Disable Confluence action tools if no Confluence connector is configured
+    has_confluence_connector = (
+        available_connectors is not None
+        and "CONFLUENCE_CONNECTOR" in available_connectors
+    )
+    if not has_confluence_connector:
+        confluence_tools = [
+            "create_confluence_page",
+            "update_confluence_page",
+            "delete_confluence_page",
+        ]
+        modified_disabled_tools.extend(confluence_tools)
+
     # Build tools using the async registry (includes MCP tools)
     _t0 = time.perf_counter()
     tools = await build_tools_async(
@@ -345,6 +410,7 @@ async def create_surfsense_deep_agent(
         system_prompt=system_prompt,
         context_schema=SurfSenseContextSchema,
         checkpointer=checkpointer,
+        middleware=[DedupHITLToolCallsMiddleware()],
         **deep_agent_kwargs,
     )
     _perf_log.info(
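The five new blocks above repeat one pattern: when a connector type is absent from available_connectors, its action tools are appended to the disabled list. A hypothetical condensation of that pattern (the table and helper names are ours, not the repo's):

# Hypothetical condensation of the connector gating above.
_ACTION_TOOLS_BY_CONNECTOR: dict[str, list[str]] = {
    "GOOGLE_DRIVE_FILE": ["create_google_drive_file", "delete_google_drive_file"],
    "GOOGLE_CALENDAR_CONNECTOR": [
        "create_calendar_event", "update_calendar_event", "delete_calendar_event",
    ],
    "GOOGLE_GMAIL_CONNECTOR": [
        "create_gmail_draft", "update_gmail_draft",
        "send_gmail_email", "trash_gmail_email",
    ],
    "JIRA_CONNECTOR": ["create_jira_issue", "update_jira_issue", "delete_jira_issue"],
    "CONFLUENCE_CONNECTOR": [
        "create_confluence_page", "update_confluence_page", "delete_confluence_page",
    ],
}

def disabled_action_tools(available_connectors: set[str] | None) -> list[str]:
    """Return the action tools whose backing connector is not configured."""
    disabled: list[str] = []
    for connector, tools in _ACTION_TOOLS_BY_CONNECTOR.items():
        if available_connectors is None or connector not in available_connectors:
            disabled.extend(tools)
    return disabled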
@@ -0,0 +1,93 @@
+"""Middleware that deduplicates HITL tool calls within a single LLM response.
+
+When the LLM emits multiple calls to the same HITL tool with the same
+primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
+only the first call is kept. Non-HITL tools are never touched.
+
+This runs in the ``after_model`` hook — **before** any tool executes — so
+the duplicate call is stripped from the AIMessage that gets checkpointed.
+That means it is also safe across LangGraph ``interrupt()`` boundaries:
+the removed call will never appear on graph resume.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from langchain.agents.middleware import AgentMiddleware, AgentState
+from langgraph.runtime import Runtime
+
+logger = logging.getLogger(__name__)
+
+_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
+    "delete_calendar_event": "event_title_or_id",
+    "update_calendar_event": "event_title_or_id",
+    "trash_gmail_email": "email_subject_or_id",
+    "update_gmail_draft": "draft_subject_or_id",
+    "delete_google_drive_file": "file_name",
+    "delete_notion_page": "page_title",
+    "update_notion_page": "page_title",
+    "delete_linear_issue": "issue_ref",
+    "update_linear_issue": "issue_ref",
+    "update_jira_issue": "issue_title_or_key",
+    "delete_jira_issue": "issue_title_or_key",
+    "update_confluence_page": "page_title_or_id",
+    "delete_confluence_page": "page_title_or_id",
+}
+
+
+class DedupHITLToolCallsMiddleware(AgentMiddleware):  # type: ignore[type-arg]
+    """Remove duplicate HITL tool calls from a single LLM response.
+
+    Only the **first** occurrence of each (tool-name, primary-arg-value)
+    pair is kept; subsequent duplicates are silently dropped.
+    """
+
+    tools = ()
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Any]
+    ) -> dict[str, Any] | None:
+        return self._dedup(state)
+
+    async def aafter_model(
+        self, state: AgentState, runtime: Runtime[Any]
+    ) -> dict[str, Any] | None:
+        return self._dedup(state)
+
+    @staticmethod
+    def _dedup(state: AgentState) -> dict[str, Any] | None:  # type: ignore[type-arg]
+        messages = state.get("messages")
+        if not messages:
+            return None
+
+        last_msg = messages[-1]
+        if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
+            return None
+
+        tool_calls: list[dict[str, Any]] = last_msg.tool_calls
+        seen: set[tuple[str, str]] = set()
+        deduped: list[dict[str, Any]] = []
+
+        for tc in tool_calls:
+            name = tc.get("name", "")
+            dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
+            if dedup_key_arg is not None:
+                arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
+                key = (name, arg_val)
+                if key in seen:
+                    logger.info(
+                        "Dedup: dropped duplicate HITL tool call %s(%s)",
+                        name,
+                        arg_val,
+                    )
+                    continue
+                seen.add(key)
+            deduped.append(tc)
+
+        if len(deduped) == len(tool_calls):
+            return None
+
+        updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
+        return {"messages": [updated_msg]}
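A quick illustration of the middleware's behavior, constructing the agent state by hand (real runs receive it from LangGraph):

# Sketch: the second delete_calendar_event call is dropped (the primary
# argument matches case-insensitively); the non-HITL call is untouched.
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[
        {"name": "delete_calendar_event",
         "args": {"event_title_or_id": "Doctor Appointment"}, "id": "call_1"},
        {"name": "delete_calendar_event",
         "args": {"event_title_or_id": "doctor appointment"}, "id": "call_2"},
        {"name": "search_knowledge_base",
         "args": {"query": "appointments"}, "id": "call_3"},
    ],
)

update = DedupHITLToolCallsMiddleware._dedup({"messages": [msg]})
assert update is not None
assert [tc["id"] for tc in update["messages"][0].tool_calls] == ["call_1", "call_3"]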
@@ -184,48 +184,6 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
 - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
 """

-_TOOL_INSTRUCTIONS["link_preview"] = """
-- link_preview: Fetch metadata for a URL to display a rich preview card.
-  - IMPORTANT: Use this tool WHENEVER the user shares or mentions a URL/link in their message.
-  - This fetches the page's Open Graph metadata (title, description, thumbnail) to show a preview card.
-  - NOTE: This tool only fetches metadata, NOT the full page content. It cannot read the article text.
-  - Trigger scenarios:
-    * User shares a URL (e.g., "Check out https://example.com")
-    * User pastes a link in their message
-    * User asks about a URL or link
-  - Args:
-    - url: The URL to fetch metadata for (must be a valid HTTP/HTTPS URL)
-  - Returns: A rich preview card with title, description, thumbnail, and domain
-  - The preview card will automatically be displayed in the chat.
-"""
-
-_TOOL_INSTRUCTIONS["display_image"] = """
-- display_image: Display an image in the chat with metadata.
-  - Use this tool ONLY when you have a valid public HTTP/HTTPS image URL to show.
-  - This displays the image with an optional title, description, and source attribution.
-  - Valid use cases:
-    * Showing an image from a URL the user explicitly mentioned in their message
-    * Displaying images found in scraped webpage content (from scrape_webpage tool)
-    * Showing a publicly accessible diagram or chart from a known URL
-    * Displaying an AI-generated image after calling the generate_image tool (ALWAYS required)
-
-  CRITICAL - NEVER USE THIS TOOL FOR USER-UPLOADED ATTACHMENTS:
-  When a user uploads/attaches an image file to their message:
-    * The image is ALREADY VISIBLE in the chat UI as a thumbnail on their message
-    * You do NOT have a URL for their uploaded image - only extracted text/description
-    * Calling display_image will FAIL and show "Image not available" error
-    * Simply analyze the image content and respond with your analysis - DO NOT try to display it
-    * The user can already see their own uploaded image - they don't need you to show it again
-
-  - Args:
-    - src: The URL of the image (MUST be a valid public HTTP/HTTPS URL that you know exists)
-    - alt: Alternative text describing the image (for accessibility)
-    - title: Optional title to display below the image
-    - description: Optional description providing context about the image
-  - Returns: An image card with the image, title, and description
-  - The image will automatically be displayed in the chat.
-"""
-
 _TOOL_INSTRUCTIONS["generate_image"] = """
 - generate_image: Generate images from text descriptions using AI image models.
   - Use this when the user asks you to create, generate, draw, design, or make an image.
@@ -233,10 +191,7 @@ _TOOL_INSTRUCTIONS["generate_image"] = """
   - Args:
     - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
     - n: Number of images to generate (1-4, default: 1)
-  - Returns: A dictionary with the generated image URL in the "src" field, along with metadata.
-  - CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
-    to actually show the image in the chat. The generate_image tool only generates the image and returns
-    the URL — it does NOT display anything. You must always follow up with display_image.
+  - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat.
   - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
     expand and improve the prompt with specific details about style, lighting, composition, and mood.
   - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
@@ -245,14 +200,11 @@ _TOOL_INSTRUCTIONS["generate_image"] = """
 _TOOL_INSTRUCTIONS["scrape_webpage"] = """
 - scrape_webpage: Scrape and extract the main content from a webpage.
   - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
-  - IMPORTANT: This is different from link_preview:
-    * link_preview: Only fetches metadata (title, description, thumbnail) for display
-    * scrape_webpage: Actually reads the FULL page content so you can analyze/summarize it
   - CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying):
     * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
     * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
     * When a URL was mentioned earlier in the conversation and the user asks for its actual content
-    * When link_preview or search_knowledge_base returned insufficient data and the user wants more
+    * When search_knowledge_base returned insufficient data and the user wants more
   - Trigger scenarios:
     * "Read this article and summarize it"
     * "What does this page say about X?"
@@ -268,9 +220,10 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
     - url: The URL of the webpage to scrape (must be HTTP/HTTPS)
     - max_length: Maximum content length to return (default: 50000 chars)
   - Returns: The page title, description, full content (in markdown), word count, and metadata
-  - After scraping, you will have the full article text and can analyze, summarize, or answer questions about it.
+  - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points.
+  - Reference the source using markdown links [descriptive text](url) — never bare URLs.
   - IMAGES: The scraped content may contain image URLs in markdown format like `![alt](url)`.
-    * When you find relevant/important images in the scraped content, use the `display_image` tool to show them to the user.
+    * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt](url)`.
     * This makes your response more visual and engaging.
     * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
     * Don't show every image - just the most relevant 1-3 images that enhance understanding.
@@ -292,6 +245,8 @@ _TOOL_INSTRUCTIONS["web_search"] = """
   - Args:
     - query: The search query - use specific, descriptive terms
     - top_k: Number of results to retrieve (default: 10, max: 50)
+  - If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content.
+  - When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs.
 """

 # Memory tool instructions have private and shared variants.
@@ -476,32 +431,31 @@ _TOOL_EXAMPLES["generate_report"] = """

 _TOOL_EXAMPLES["scrape_webpage"] = """
 - User: "Check out https://dev.to/some-article"
-  - Call: `link_preview(url="https://dev.to/some-article")`
-  - Then provide your analysis of the content.
+  - Call: `scrape_webpage(url="https://dev.to/some-article")`
+  - Respond with a structured analysis — key points, takeaways.
 - User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
   - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
-  - Then provide a summary based on the scraped text.
+  - Respond with a thorough summary using headings and bullet points.
 - User: (after discussing https://example.com/stats) "Can you get the live data from that page?"
   - Call: `scrape_webpage(url="https://example.com/stats")`
 - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool.
-"""
-
-_TOOL_EXAMPLES["display_image"] = """
-- User: "Show me this image: https://example.com/image.png"
-  - Call: `display_image(src="https://example.com/image.png", alt="User shared image")`
-- User uploads an image file and asks: "What is this image about?"
-  - DO NOT call display_image! The user's uploaded image is already visible in the chat.
-  - Simply analyze the image content and respond directly.
+- User: "https://example.com/blog/weekend-recipes"
+  - Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")`
+  - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content.
 """

 _TOOL_EXAMPLES["generate_image"] = """
 - User: "Generate an image of a cat"
-  - Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
-  - Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")`
+  - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
+  - The generated image will automatically be displayed in the chat.
 - User: "Draw me a logo for a coffee shop called Bean Dream"
-  - Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
-  - Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
+  - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
+  - The generated image will automatically be displayed in the chat.
+- User: "Show me this image: https://example.com/image.png"
+  - Simply include it in your response using markdown: `![image](https://example.com/image.png)`
+- User uploads an image file and asks: "What is this image about?"
+  - The user's uploaded image is already visible in the chat.
+  - Simply analyze the image content and respond directly.
 """

 _TOOL_EXAMPLES["web_search"] = """
@@ -522,8 +476,6 @@ _ALL_TOOL_NAMES_ORDERED = [
     "generate_podcast",
     "generate_video_presentation",
     "generate_report",
-    "link_preview",
-    "display_image",
     "generate_image",
     "scrape_webpage",
     "save_memory",
@@ -764,7 +716,7 @@ Do not use the sandbox for:

 When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
 - **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
-- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
+- **DO NOT use markdown image syntax** for files created inside the sandbox. Sandbox files are not accessible via public URLs and will show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
 - You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
 - Always describe what the file contains in your response text so the user knows what they are downloading.
 - IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
@@ -10,8 +10,6 @@ Available tools:
 - generate_podcast: Generate audio podcasts from content
 - generate_video_presentation: Generate video presentations with slides and narration
 - generate_image: Generate images from text descriptions using AI models
-- link_preview: Fetch rich previews for URLs
-- display_image: Display images in chat
 - scrape_webpage: Extract content from webpages
 - save_memory: Store facts/preferences about the user
 - recall_memory: Retrieve relevant user memories
@@ -19,7 +17,6 @@ Available tools:

 # Registry exports
 # Tool factory exports (for direct use)
-from .display_image import create_display_image_tool
 from .generate_image import create_generate_image_tool
 from .knowledge_base import (
     CONNECTOR_DESCRIPTIONS,
@@ -27,7 +24,6 @@ from .knowledge_base import (
     format_documents_for_context,
     search_knowledge_base_async,
 )
-from .link_preview import create_link_preview_tool
 from .podcast import create_generate_podcast_tool
 from .registry import (
     BUILTIN_TOOLS,
@@ -50,11 +46,9 @@ __all__ = [
     "ToolDefinition",
     "build_tools",
     # Tool factories
-    "create_display_image_tool",
     "create_generate_image_tool",
     "create_generate_podcast_tool",
     "create_generate_video_presentation_tool",
-    "create_link_preview_tool",
     "create_recall_memory_tool",
     "create_save_memory_tool",
     "create_scrape_webpage_tool",
@@ -0,0 +1,11 @@
+"""Confluence tools for creating, updating, and deleting pages."""
+
+from .create_page import create_create_confluence_page_tool
+from .delete_page import create_delete_confluence_page_tool
+from .update_page import create_update_confluence_page_tool
+
+__all__ = [
+    "create_create_confluence_page_tool",
+    "create_delete_confluence_page_tool",
+    "create_update_confluence_page_tool",
+]
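The tool below pauses on LangGraph's interrupt() and expects the client to resume with a decisions payload, mirroring the parsing logic visible in create_confluence_page. A hedged sketch of the resume side (the graph object and thread id are placeholders):

# Sketch: resuming a paused HITL tool call. The payload shape mirrors the
# decision parsing in create_confluence_page; `graph` stands in for the
# compiled, checkpointed agent.
from langgraph.types import Command

def approve_with_edits(graph, thread_id: str) -> None:
    config = {"configurable": {"thread_id": thread_id}}
    # Approve, optionally overriding the proposed arguments.
    graph.invoke(
        Command(
            resume={
                "decisions": [
                    {
                        "type": "approve",
                        "edited_action": {
                            "args": {"title": "Team Charter", "space_id": "12345"}
                        },
                    }
                ]
            }
        ),
        config,
    )

def reject(graph, thread_id: str) -> None:
    config = {"configurable": {"thread_id": thread_id}}
    # The tool then returns status "rejected" and must not be retried.
    graph.invoke(Command(resume={"decisions": [{"type": "reject"}]}), config)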
@@ -0,0 +1,237 @@
+import logging
+from typing import Any
+
+from langchain_core.tools import tool
+from langgraph.types import interrupt
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm.attributes import flag_modified
+
+from app.connectors.confluence_history import ConfluenceHistoryConnector
+from app.services.confluence import ConfluenceToolMetadataService
+
+logger = logging.getLogger(__name__)
+
+
+def create_create_confluence_page_tool(
+    db_session: AsyncSession | None = None,
+    search_space_id: int | None = None,
+    user_id: str | None = None,
+    connector_id: int | None = None,
+):
+    @tool
+    async def create_confluence_page(
+        title: str,
+        content: str | None = None,
+        space_id: str | None = None,
+    ) -> dict[str, Any]:
+        """Create a new page in Confluence.
+
+        Use this tool when the user explicitly asks to create a new Confluence page.
+
+        Args:
+            title: Title of the page.
+            content: Optional HTML/storage format content for the page body.
+            space_id: Optional Confluence space ID to create the page in.
+
+        Returns:
+            Dictionary with status, page_id, and message.
+
+        IMPORTANT:
+        - If status is "rejected", do NOT retry.
+        - If status is "insufficient_permissions", inform user to re-authenticate.
+        """
+        logger.info(f"create_confluence_page called: title='{title}'")
+
+        if db_session is None or search_space_id is None or user_id is None:
+            return {
+                "status": "error",
+                "message": "Confluence tool not properly configured.",
+            }
+
+        try:
+            metadata_service = ConfluenceToolMetadataService(db_session)
+            context = await metadata_service.get_creation_context(
+                search_space_id, user_id
+            )
+
+            if "error" in context:
+                return {"status": "error", "message": context["error"]}
+
+            accounts = context.get("accounts", [])
+            if accounts and all(a.get("auth_expired") for a in accounts):
+                return {
+                    "status": "auth_error",
+                    "message": "All connected Confluence accounts need re-authentication.",
+                    "connector_type": "confluence",
+                }
+
+            approval = interrupt(
+                {
+                    "type": "confluence_page_creation",
+                    "action": {
+                        "tool": "create_confluence_page",
+                        "params": {
+                            "title": title,
+                            "content": content,
+                            "space_id": space_id,
+                            "connector_id": connector_id,
+                        },
+                    },
+                    "context": context,
+                }
+            )
+
+            decisions_raw = (
+                approval.get("decisions", []) if isinstance(approval, dict) else []
+            )
+            decisions = (
+                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
+            )
+            decisions = [d for d in decisions if isinstance(d, dict)]
+            if not decisions:
+                return {"status": "error", "message": "No approval decision received"}
+
+            decision = decisions[0]
+            decision_type = decision.get("type") or decision.get("decision_type")
+
+            if decision_type == "reject":
+                return {
+                    "status": "rejected",
+                    "message": "User declined. The page was not created.",
+                }
+
+            final_params: dict[str, Any] = {}
+            edited_action = decision.get("edited_action")
+            if isinstance(edited_action, dict):
+                edited_args = edited_action.get("args")
+                if isinstance(edited_args, dict):
+                    final_params = edited_args
+            elif isinstance(decision.get("args"), dict):
+                final_params = decision["args"]
+
+            final_title = final_params.get("title", title)
+            final_content = final_params.get("content", content) or ""
+            final_space_id = final_params.get("space_id", space_id)
+            final_connector_id = final_params.get("connector_id", connector_id)
+
+            if not final_title or not final_title.strip():
+                return {"status": "error", "message": "Page title cannot be empty."}
+            if not final_space_id:
+                return {"status": "error", "message": "A space must be selected."}
+
+            from sqlalchemy.future import select
+
+            from app.db import SearchSourceConnector, SearchSourceConnectorType
+
+            actual_connector_id = final_connector_id
+            if actual_connector_id is None:
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type
+                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "No Confluence connector found.",
+                    }
+                actual_connector_id = connector.id
+            else:
+                result = await db_session.execute(
+                    select(SearchSourceConnector).filter(
+                        SearchSourceConnector.id == actual_connector_id,
+                        SearchSourceConnector.search_space_id == search_space_id,
+                        SearchSourceConnector.user_id == user_id,
+                        SearchSourceConnector.connector_type
+                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
+                    )
+                )
+                connector = result.scalars().first()
+                if not connector:
+                    return {
+                        "status": "error",
+                        "message": "Selected Confluence connector is invalid.",
+                    }
+
+            try:
+                client = ConfluenceHistoryConnector(
+                    session=db_session, connector_id=actual_connector_id
+                )
+                api_result = await client.create_page(
+                    space_id=final_space_id,
+                    title=final_title,
+                    body=final_content,
+                )
+                await client.close()
+            except Exception as api_err:
+                if (
+                    "http 403" in str(api_err).lower()
+                    or "status code 403" in str(api_err).lower()
+                ):
+                    try:
+                        _conn = connector
+                        _conn.config = {**_conn.config, "auth_expired": True}
+                        flag_modified(_conn, "config")
+                        await db_session.commit()
+                    except Exception:
+                        pass
+                    return {
+                        "status": "insufficient_permissions",
+                        "connector_id": actual_connector_id,
+                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
+                    }
+                raise
+
+            page_id = str(api_result.get("id", ""))
+            page_links = (
+                api_result.get("_links", {}) if isinstance(api_result, dict) else {}
+            )
+            page_url = ""
+            if page_links.get("base") and page_links.get("webui"):
+                page_url = f"{page_links['base']}{page_links['webui']}"
+
+            kb_message_suffix = ""
+            try:
+                from app.services.confluence import ConfluenceKBSyncService
+
+                kb_service = ConfluenceKBSyncService(db_session)
+                kb_result = await kb_service.sync_after_create(
+                    page_id=page_id,
+                    page_title=final_title,
+                    space_id=final_space_id,
+                    body_content=final_content,
+                    connector_id=actual_connector_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                )
+                if kb_result["status"] == "success":
+                    kb_message_suffix = " Your knowledge base has also been updated."
+                else:
+                    kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+            except Exception as kb_err:
+                logger.warning(f"KB sync after create failed: {kb_err}")
+                kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
+
+            return {
+                "status": "success",
+                "page_id": page_id,
+                "page_url": page_url,
+                "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
+            }
+
+        except Exception as e:
+            from langgraph.errors import GraphInterrupt
+
+            if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the page.",
|
||||
}
|
||||
|
||||
return create_confluence_page
|
||||
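Editor's note: the decision parsing above is deliberately lenient. The resume payload may carry a single decision dict or a list under "decisions", the decision kind may arrive as "type" or "decision_type", edited parameters may sit under "edited_action.args" or a top-level "args", and any non-reject decision proceeds. A minimal sketch of resume payloads this parser would accept (the values are illustrative, not part of the diff):

# Illustrative resume payloads matching the parser above (not from this PR).
approve = {"decisions": [{"type": "accept"}]}
reject = {"decisions": [{"decision_type": "reject"}]}
edit_then_approve = {
    "decisions": [
        {
            "type": "accept",
            # Edited args override title/content/space_id before the API call.
            "edited_action": {"args": {"title": "Q3 Roadmap", "space_id": "12345"}},
        }
    ]
}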
@@ -0,0 +1,215 @@
import logging
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def delete_confluence_page(
        page_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Confluence page.

        Use this tool when the user asks to delete or remove a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            delete_from_kb: Whether to also remove from the knowledge base.

        Returns:
            Dictionary with status, message, and deleted_from_kb.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message to the user.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            metadata_service = ConfluenceToolMetadataService(db_session)
            context = await metadata_service.get_deletion_context(
                search_space_id, user_id, page_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "confluence",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            page_data = context["page"]
            page_id = page_data["page_id"]
            page_title = page_data.get("page_title", "")
            document_id = page_data["document_id"]
            connector_id_from_context = context.get("account", {}).get("id")

            approval = interrupt(
                {
                    "type": "confluence_page_deletion",
                    "action": {
                        "tool": "delete_confluence_page",
                        "params": {
                            "page_id": page_id,
                            "connector_id": connector_id_from_context,
                            "delete_from_kb": delete_from_kb,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The page was not deleted.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_page_id = final_params.get("page_id", page_id)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this page.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Confluence connector is invalid.",
                }

            try:
                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                await client.delete_page(final_page_id)
                await client.close()
            except Exception as api_err:
                if (
                    "http 403" in str(api_err).lower()
                    or "status code 403" in str(api_err).lower()
                ):
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()

            message = f"Confluence page '{page_title}' deleted successfully."
            if deleted_from_kb:
                message += " Also removed from the knowledge base."

            return {
                "status": "success",
                "page_id": final_page_id,
                "deleted_from_kb": deleted_from_kb,
                "message": message,
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting Confluence page: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the page.",
            }

    return delete_confluence_page
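Editor's note: each factory is self-contained, so wiring a tool into an agent reduces to calling the factory with the request-scoped session and IDs. A hedged sketch; the `llm`, `session`, and ID names are placeholders for illustration, not part of this diff:

# Hedged wiring sketch; everything except the factory name is illustrative.
delete_tool = create_delete_confluence_page_tool(
    db_session=session,              # request-scoped AsyncSession
    search_space_id=search_space_id,
    user_id=user_id,
)
llm_with_tools = llm.bind_tools([delete_tool])  # any LangChain chat model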
@@ -0,0 +1,244 @@
import logging
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

logger = logging.getLogger(__name__)


def create_update_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def update_confluence_page(
        page_title_or_id: str,
        new_title: str | None = None,
        new_content: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Confluence page.

        Use this tool when the user asks to modify or edit a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            new_title: Optional new title for the page.
            new_content: Optional new HTML/storage format content.

        Returns:
            Dictionary with status and message.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message to the user.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            metadata_service = ConfluenceToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, page_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "confluence",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            page_data = context["page"]
            page_id = page_data["page_id"]
            current_title = page_data["page_title"]
            current_body = page_data.get("body", "")
            current_version = page_data.get("version", 1)
            document_id = page_data.get("document_id")
            connector_id_from_context = context.get("account", {}).get("id")

            approval = interrupt(
                {
                    "type": "confluence_page_update",
                    "action": {
                        "tool": "update_confluence_page",
                        "params": {
                            "page_id": page_id,
                            "document_id": document_id,
                            "new_title": new_title,
                            "new_content": new_content,
                            "version": current_version,
                            "connector_id": connector_id_from_context,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The page was not updated.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_page_id = final_params.get("page_id", page_id)
            final_title = final_params.get("new_title", new_title) or current_title
            final_content = final_params.get("new_content", new_content)
            if final_content is None:
                final_content = current_body
            final_version = final_params.get("version", current_version)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_document_id = final_params.get("document_id", document_id)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this page.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Confluence connector is invalid.",
                }

            try:
                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                api_result = await client.update_page(
                    page_id=final_page_id,
                    title=final_title,
                    body=final_content,
                    version_number=final_version + 1,
                )
                await client.close()
            except Exception as api_err:
                if (
                    "http 403" in str(api_err).lower()
                    or "status code 403" in str(api_err).lower()
                ):
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            page_links = (
                api_result.get("_links", {}) if isinstance(api_result, dict) else {}
            )
            page_url = ""
            if page_links.get("base") and page_links.get("webui"):
                page_url = f"{page_links['base']}{page_links['webui']}"

            kb_message_suffix = ""
            if final_document_id:
                try:
                    from app.services.confluence import ConfluenceKBSyncService

                    kb_service = ConfluenceKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=final_document_id,
                        page_id=final_page_id,
                        user_id=user_id,
                        search_space_id=search_space_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = (
                            " The knowledge base will be updated in the next sync."
                        )
                except Exception as kb_err:
                    logger.warning(f"KB sync after update failed: {kb_err}")
                    kb_message_suffix = (
                        " The knowledge base will be updated in the next sync."
                    )

            return {
                "status": "success",
                "page_id": final_page_id,
                "page_url": page_url,
                "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error updating Confluence page: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the page.",
            }

    return update_confluence_page
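Editor's note: the update path sends `version_number=final_version + 1`, matching Confluence's requirement that each page update increment the stored version. As with the other tools, an approval can rewrite the pending update before it executes; an illustrative resume payload (values are examples only):

# Illustrative: edit the pending update during approval. Omitted fields fall
# back to the page's current title/body per the parser above.
resume = {
    "decisions": [
        {
            "type": "accept",
            "edited_action": {
                "args": {
                    "new_title": "Revised title",
                    "new_content": "<p>Updated body</p>",
                }
            },
        }
    ]
}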
@@ -1,111 +0,0 @@
"""
|
||||
Display image tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool for displaying images in the chat UI
|
||||
with metadata like title, description, and source attribution.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def extract_domain(url: str) -> str:
|
||||
"""Extract the domain from a URL."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc
|
||||
# Remove 'www.' prefix if present
|
||||
if domain.startswith("www."):
|
||||
domain = domain[4:]
|
||||
return domain
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def generate_image_id(src: str) -> str:
|
||||
"""Generate a unique ID for an image."""
|
||||
hash_val = hashlib.md5(src.encode()).hexdigest()[:12]
|
||||
return f"image-{hash_val}"
|
||||
|
||||
|
||||
def create_display_image_tool():
|
||||
"""
|
||||
Factory function to create the display_image tool.
|
||||
|
||||
Returns:
|
||||
A configured tool function for displaying images.
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def display_image(
|
||||
src: str,
|
||||
alt: str = "Image",
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Display an image in the chat with metadata.
|
||||
|
||||
Use this tool when you want to show an image to the user.
|
||||
This displays the image with an optional title, description,
|
||||
and source attribution.
|
||||
|
||||
Common use cases:
|
||||
- Showing an image from a URL the user mentioned
|
||||
- Displaying a diagram or chart you're referencing
|
||||
- Showing example images when explaining concepts
|
||||
|
||||
Args:
|
||||
src: The URL of the image to display (must be a valid HTTP/HTTPS URL)
|
||||
alt: Alternative text describing the image (for accessibility)
|
||||
title: Optional title to display below the image
|
||||
description: Optional description providing context about the image
|
||||
|
||||
Returns:
|
||||
A dictionary containing image metadata for the UI to render:
|
||||
- id: Unique identifier for this image
|
||||
- assetId: The image URL (for deduplication)
|
||||
- src: The image URL
|
||||
- alt: Alt text for accessibility
|
||||
- title: Image title (if provided)
|
||||
- description: Image description (if provided)
|
||||
- domain: Source domain
|
||||
"""
|
||||
image_id = generate_image_id(src)
|
||||
|
||||
# Ensure URL has protocol
|
||||
if not src.startswith(("http://", "https://")):
|
||||
src = f"https://{src}"
|
||||
|
||||
domain = extract_domain(src)
|
||||
|
||||
# Determine aspect ratio based on image source
|
||||
# AI-generated images should use "auto" to preserve their native ratio
|
||||
is_generated = "/image-generations/" in src
|
||||
if is_generated:
|
||||
ratio = "auto"
|
||||
domain = "ai-generated"
|
||||
elif "unsplash.com" in src or "pexels.com" in src:
|
||||
ratio = "16:9"
|
||||
elif (
|
||||
"imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
|
||||
):
|
||||
ratio = "auto"
|
||||
else:
|
||||
ratio = "auto"
|
||||
|
||||
return {
|
||||
"id": image_id,
|
||||
"assetId": src,
|
||||
"src": src,
|
||||
"alt": alt,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"domain": domain,
|
||||
"ratio": ratio,
|
||||
}
|
||||
|
||||
return display_image
|
||||
@@ -2,8 +2,7 @@
Image generation tool for the SurfSense agent.

This module provides a tool that generates images using litellm.aimage_generation()
and returns the result via the existing display_image tool format so the frontend
renders the generated image inline in the chat.
and returns the result directly in a format the frontend Image component can render.

Config resolution:
1. Uses the search space's image_generation_config_id preference
@@ -11,6 +10,7 @@ Config resolution:
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
"""

import hashlib
import logging
from typing import Any

@@ -222,11 +222,17 @@ def create_generate_image_tool(
        else:
            return {"error": "No displayable image data in the response"}

        image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}"

        return {
            "id": image_id,
            "assetId": image_url,
            "src": image_url,
            "alt": revised_prompt or prompt,
            "title": "Generated Image",
            "description": revised_prompt if revised_prompt != prompt else None,
            "domain": "ai-generated",
            "ratio": "auto",
            "generated": True,
            "prompt": prompt,
            "image_count": len(images),
@@ -0,0 +1,19 @@
from app.agents.new_chat.tools.gmail.create_draft import (
    create_create_gmail_draft_tool,
)
from app.agents.new_chat.tools.gmail.send_email import (
    create_send_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.trash_email import (
    create_trash_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.update_draft import (
    create_update_gmail_draft_tool,
)

__all__ = [
    "create_create_gmail_draft_tool",
    "create_send_gmail_email_tool",
    "create_trash_gmail_email_tool",
    "create_update_gmail_draft_tool",
]
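Editor's note: with this `__init__`, callers import the four Gmail tool factories from the package root rather than the individual modules. A brief usage sketch (the session and IDs are placeholders, not part of this diff):

# Illustrative import via the new package root.
from app.agents.new_chat.tools.gmail import create_create_gmail_draft_tool

draft_tool = create_create_gmail_draft_tool(
    db_session=session, search_space_id=search_space_id, user_id=user_id
)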
@@ -0,0 +1,341 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_create_gmail_draft_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_gmail_draft(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Create a draft email in Gmail.

        Use when the user asks to draft, compose, or prepare an email without
        sending it.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - draft_id: Gmail draft ID (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Draft an email to alice@example.com about the meeting"
        - "Compose a reply to Bob about the project update"
        """
        logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Gmail accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            logger.info(
                f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
            )
            approval = interrupt(
                {
                    "type": "gmail_draft_creation",
                    "action": {
                        "tool": "create_gmail_draft",
                        "params": {
                            "to": to,
                            "subject": subject,
                            "body": body,
                            "cc": cc,
                            "bcc": bcc,
                            "connector_id": None,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_to = final_params.get("to", to)
            final_subject = final_params.get("subject", subject)
            final_body = final_params.get("body", body)
            final_cc = final_params.get("cc", cc)
            final_bcc = final_params.get("bcc", bcc)
            final_connector_id = final_params.get("connector_id")

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .create(userId="me", body={"message": {"raw": raw}})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Gmail draft created: id={created.get('id')}")

            kb_message_suffix = ""
            try:
                from app.services.gmail import GmailKBSyncService

                kb_service = GmailKBSyncService(db_session)
                draft_message = created.get("message", {})
                kb_result = await kb_service.sync_after_create(
                    message_id=draft_message.get("id", ""),
                    thread_id=draft_message.get("threadId", ""),
                    subject=final_subject,
                    sender="me",
                    date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    body_text=final_body,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    draft_id=created.get("id"),
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "draft_id": created.get("id"),
                "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the draft. Please try again.",
            }

    return create_gmail_draft
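Editor's note: the draft payload construction above follows the Gmail API contract: drafts.create expects a base64url-encoded RFC 2822 message nested under "message.raw". A standalone sketch of just that step (addresses and text are examples):

# Standalone sketch of the payload shape used by drafts().create above.
import base64
from email.mime.text import MIMEText

msg = MIMEText("Hello Alice")
msg["to"] = "alice@example.com"
msg["subject"] = "Meeting notes"
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode()
draft_body = {"message": {"raw": raw}}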
343 surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py Normal file
@@ -0,0 +1,343 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_send_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def send_gmail_email(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Send an email via Gmail.

        Use when the user explicitly asks to send an email. This sends the
        email immediately - it cannot be unsent.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - message_id: Gmail message ID (if success)
            - thread_id: Gmail thread ID (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Send an email to alice@example.com about the meeting"
        - "Email Bob the project update"
        """
        logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Gmail accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            logger.info(
                f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
            )
            approval = interrupt(
                {
                    "type": "gmail_email_send",
                    "action": {
                        "tool": "send_gmail_email",
                        "params": {
                            "to": to,
                            "subject": subject,
                            "body": body,
                            "cc": cc,
                            "bcc": bcc,
                            "connector_id": None,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_to = final_params.get("to", to)
            final_subject = final_params.get("subject", subject)
            final_body = final_params.get("body", body)
            final_cc = final_params.get("cc", cc)
            final_bcc = final_params.get("bcc", bcc)
            final_connector_id = final_params.get("connector_id")

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                sent = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .send(userId="me", body={"raw": raw})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.gmail import GmailKBSyncService

                kb_service = GmailKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    message_id=sent.get("id", ""),
                    thread_id=sent.get("threadId", ""),
                    subject=final_subject,
                    sender="me",
                    date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    body_text=final_body,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after send failed: {kb_err}")
                kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "message_id": sent.get("id"),
                "thread_id": sent.get("threadId"),
                "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error sending Gmail email: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while sending the email. Please try again.",
            }

    return send_gmail_email
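Editor's note: apart from the interrupt type and user-facing messages, the only API-level difference from the draft tool is the endpoint: messages.send takes the raw message directly, while drafts.create nests it under "message". Side by side, with the payload shapes taken from the two files above (`gmail_service` and `raw` as built there):

# drafts.create nests the raw message; messages.send takes it directly.
created = gmail_service.users().drafts().create(
    userId="me", body={"message": {"raw": raw}}
).execute()
sent = gmail_service.users().messages().send(
    userId="me", body={"raw": raw}
).execute()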
337 surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py Normal file
@@ -0,0 +1,337 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_trash_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def trash_gmail_email(
        email_subject_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move an email or draft to trash in Gmail.

        Use when the user asks to delete, remove, or trash an email or draft.

        Args:
            email_subject_or_id: The exact subject line or message ID of the
                email to trash (as it appears in the inbox).
            delete_from_kb: Whether to also remove the email from the knowledge base.
                Default is False.
                Set to True to remove from both Gmail and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - message_id: Gmail message ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the email subject or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.
        Examples:
        - "Delete the email about 'Meeting Cancelled'"
        - "Trash the email from Bob about the project"
        """
        logger.info(
            f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_trash_context(
                search_space_id, user_id, email_subject_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Email not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch trash context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Gmail account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            email = context["email"]
            message_id = email["message_id"]
            document_id = email.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not message_id:
                return {
                    "status": "error",
                    "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
                }

            logger.info(
                f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
            )
            approval = interrupt(
                {
                    "type": "gmail_email_trash",
                    "action": {
                        "tool": "trash_gmail_email",
                        "params": {
                            "message_id": message_id,
                            "connector_id": connector_id_from_context,
                            "delete_from_kb": delete_from_kb,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
                }

            edited_action = decision.get("edited_action")
            final_params: dict[str, Any] = {}
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_message_id = final_params.get("message_id", message_id)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this email.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_gmail_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Gmail connector is invalid or has been disconnected.",
                }

            logger.info(
                f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .trash(userId="me", id=final_message_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Gmail email trashed: message_id={final_message_id}")

            trash_result: dict[str, Any] = {
                "status": "success",
                "message_id": final_message_id,
                "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    trash_result["warning"] = (
|
||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the email. Please try again.",
|
||||
}
|
||||
|
||||
return trash_gmail_email
|
||||
|
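Note (not part of the diff): all of these HITL tools follow the same contract — interrupt() surfaces an action payload, and execution resumes once the client replies with a decisions list. A minimal sketch of the client side, assuming a compiled LangGraph graph and a hypothetical thread id; the resume shape matches what the tool parses above:

    from langgraph.types import Command

    config = {"configurable": {"thread_id": "chat-123"}}  # hypothetical thread

    # First invocation runs until interrupt() fires and surfaces the payload.
    state = graph.invoke({"messages": [...]}, config=config)

    # After the user approves (optionally editing params in the approval card),
    # resume with a decision; the tool reads approval["decisions"][0].
    decision = {"type": "accept", "edited_action": {"args": {"delete_from_kb": True}}}
    graph.invoke(Command(resume={"decisions": [decision]}), config=config)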
@@ -0,0 +1,438 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_update_gmail_draft_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def update_gmail_draft(
        draft_subject_or_id: str,
        body: str,
        to: str | None = None,
        subject: str | None = None,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Gmail draft.

        Use when the user asks to modify, edit, or add content to an existing
        email draft. This replaces the draft content with the new version.
        The user will be able to review and edit the content before it is applied.

        If the user simply wants to "edit" a draft without specifying exact changes,
        generate the body yourself using your best understanding of the conversation
        context. The user will review and can freely edit the content in the approval
        card before confirming.

        IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
        deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
        calendar events, or any other content type.

        Args:
            draft_subject_or_id: The exact subject line of the draft to update
                (as it appears in Gmail drafts).
            body: The full updated body content for the draft. Generate this
                yourself based on the user's request and conversation context.
            to: Optional new recipient email address (keeps original if omitted).
            subject: Optional new subject line (keeps original if omitted).
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - draft_id: Gmail draft ID (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the draft subject or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Update the Kurseong Plan draft with the new itinerary details"
        - "Edit my draft about the project proposal and change the recipient"
        - "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
        """
        logger.info(
            f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, draft_subject_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Draft not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch update context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Gmail account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            email = context["email"]
            message_id = email["message_id"]
            document_id = email.get("document_id")
            connector_id_from_context = account["id"]
            draft_id_from_context = context.get("draft_id")

            original_subject = email.get("subject", draft_subject_or_id)
            final_subject_default = subject if subject else original_subject
            final_to_default = to if to else ""

            logger.info(
                f"Requesting approval for updating Gmail draft: '{original_subject}' "
                f"(message_id={message_id}, draft_id={draft_id_from_context})"
            )
            approval = interrupt(
                {
                    "type": "gmail_draft_update",
                    "action": {
                        "tool": "update_gmail_draft",
                        "params": {
                            "message_id": message_id,
                            "draft_id": draft_id_from_context,
                            "to": final_to_default,
                            "subject": final_subject_default,
                            "body": body,
                            "cc": cc,
                            "bcc": bcc,
                            "connector_id": connector_id_from_context,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_to = final_params.get("to", final_to_default)
            final_subject = final_params.get("subject", final_subject_default)
            final_body = final_params.get("body", body)
            final_cc = final_params.get("cc", cc)
            final_bcc = final_params.get("bcc", bcc)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_draft_id = final_params.get("draft_id", draft_id_from_context)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this draft.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_gmail_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Gmail connector is invalid or has been disconnected.",
                }

            logger.info(
                f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            # Resolve draft_id if not already available
            if not final_draft_id:
                logger.info(
                    f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
                )
                final_draft_id = await _find_draft_id_by_message(
                    gmail_service, message_id
                )

            if not final_draft_id:
                return {
                    "status": "error",
                    "message": (
                        "Could not find this draft in Gmail. "
                        "It may have already been sent or deleted."
                    ),
                }

            message = MIMEText(final_body)
            if final_to:
                message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .update(
                            userId="me",
                            id=final_draft_id,
                            body={"message": {"raw": raw}},
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                if isinstance(api_err, HttpError) and api_err.resp.status == 404:
                    return {
                        "status": "error",
                        "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                    }
                raise

            logger.info(f"Gmail draft updated: id={updated.get('id')}")

            kb_message_suffix = ""
            if document_id:
                try:
                    from sqlalchemy.future import select as sa_select
                    from sqlalchemy.orm.attributes import flag_modified

                    from app.db import Document

                    doc_result = await db_session.execute(
                        sa_select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        document.source_markdown = final_body
                        document.title = final_subject
                        meta = dict(document.document_metadata or {})
                        meta["subject"] = final_subject
                        meta["draft_id"] = updated.get("id", final_draft_id)
                        updated_msg = updated.get("message", {})
                        if updated_msg.get("id"):
                            meta["message_id"] = updated_msg["id"]
                        document.document_metadata = meta
                        flag_modified(document, "document_metadata")
                        await db_session.commit()
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                        logger.info(
                            f"KB document {document_id} updated for draft {final_draft_id}"
                        )
                    else:
                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB update after draft edit failed: {kb_err}")
                    await db_session.rollback()
                    kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "draft_id": updated.get("id"),
                "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the draft. Please try again.",
            }

    return update_gmail_draft


async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
    """Look up a draft's ID by its message ID via the Gmail API."""
    try:
        page_token = None
        while True:
            kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
            if page_token:
                kwargs["pageToken"] = page_token

            response = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda kwargs=kwargs: (
                    gmail_service.users().drafts().list(**kwargs).execute()
                ),
            )

            for draft in response.get("drafts", []):
                if draft.get("message", {}).get("id") == message_id:
                    return draft["id"]

            page_token = response.get("nextPageToken")
            if not page_token:
                break

        return None
    except Exception as e:
        logger.warning(f"Failed to look up draft by message_id: {e}")
        return None
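Note (not part of the diff): drafts().update replaces the entire stored draft with the raw RFC 2822 message it receives, which is why the tool re-applies to/subject/cc/bcc on every edit rather than sending only the changed fields. A minimal sketch of the encoding round-trip, with hypothetical values:

    import base64
    from email.mime.text import MIMEText

    msg = MIMEText("Updated itinerary:\n- Day 1: travel\n- Day 2: hike")
    msg["to"] = "friend@example.com"       # omit a header and it is dropped
    msg["subject"] = "Kurseong Plan"
    raw = base64.urlsafe_b64encode(msg.as_bytes()).decode()
    body = {"message": {"raw": raw}}       # payload passed to drafts().update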
@@ -0,0 +1,15 @@
from app.agents.new_chat.tools.google_calendar.create_event import (
    create_create_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.delete_event import (
    create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
    create_update_calendar_event_tool,
)

__all__ = [
    "create_create_calendar_event_tool",
    "create_delete_calendar_event_tool",
    "create_update_calendar_event_tool",
]
@@ -0,0 +1,352 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def create_create_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_calendar_event(
        summary: str,
        start_datetime: str,
        end_datetime: str,
        description: str | None = None,
        location: str | None = None,
        attendees: list[str] | None = None,
    ) -> dict[str, Any]:
        """Create a new event on Google Calendar.

        Use when the user asks to schedule, create, or add a calendar event.
        Ask for event details if not provided.

        Args:
            summary: The event title.
            start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
            end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
            description: Optional event description.
            location: Optional event location.
            attendees: Optional list of attendee email addresses.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - html_link: URL to open the event (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.

        Examples:
        - "Schedule a meeting with John tomorrow at 10am"
        - "Create a calendar event for the team standup"
        """
        logger.info(
            f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning(
                    "All Google Calendar accounts have expired authentication"
                )
                return {
                    "status": "auth_error",
                    "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            logger.info(
                f"Requesting approval for creating calendar event: summary='{summary}'"
            )
            approval = interrupt(
                {
                    "type": "google_calendar_event_creation",
                    "action": {
                        "tool": "create_calendar_event",
                        "params": {
                            "summary": summary,
                            "start_datetime": start_datetime,
                            "end_datetime": end_datetime,
                            "description": description,
                            "location": location,
                            "attendees": attendees,
                            "timezone": context.get("timezone"),
                            "connector_id": None,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_summary = final_params.get("summary", summary)
            final_start_datetime = final_params.get("start_datetime", start_datetime)
            final_end_datetime = final_params.get("end_datetime", end_datetime)
            final_description = final_params.get("description", description)
            final_location = final_params.get("location", location)
            final_attendees = final_params.get("attendees", attendees)
            final_connector_id = final_params.get("connector_id")

            if not final_summary or not final_summary.strip():
                return {"status": "error", "message": "Event summary cannot be empty."}

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            tz = context.get("timezone", "UTC")
            event_body: dict[str, Any] = {
                "summary": final_summary,
                "start": {"dateTime": final_start_datetime, "timeZone": tz},
                "end": {"dateTime": final_end_datetime, "timeZone": tz},
            }
            if final_description:
                event_body["description"] = final_description
            if final_location:
                event_body["location"] = final_location
            if final_attendees:
                event_body["attendees"] = [
                    {"email": e.strip()} for e in final_attendees if e.strip()
                ]

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .insert(calendarId="primary", body=event_body)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.google_calendar import GoogleCalendarKBSyncService

                kb_service = GoogleCalendarKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    event_id=created.get("id"),
                    event_summary=final_summary,
                    calendar_id="primary",
                    start_time=final_start_datetime,
                    end_time=final_end_datetime,
                    location=final_location,
                    html_link=created.get("htmlLink"),
                    description=final_description,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "event_id": created.get("id"),
                "html_link": created.get("htmlLink"),
                "message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the event. Please try again.",
            }

    return create_calendar_event
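Note (not part of the diff): the Calendar v3 insert call accepts naive ISO 8601 datetimes as long as each dateTime carries a timeZone, which is why the tool threads the workspace timezone from the metadata context into both start and end. A minimal sketch with hypothetical values, reusing the service built above:

    event_body = {
        "summary": "Team standup",
        # Naive datetimes are interpreted in the given timeZone.
        "start": {"dateTime": "2026-03-20T10:00:00", "timeZone": "Asia/Kolkata"},
        "end": {"dateTime": "2026-03-20T10:15:00", "timeZone": "Asia/Kolkata"},
        "attendees": [{"email": "a@example.com"}, {"email": "b@example.com"}],
    }
    created = service.events().insert(calendarId="primary", body=event_body).execute()
    print(created["htmlLink"])  # the link the tool returns to the user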
@@ -0,0 +1,332 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def delete_calendar_event(
        event_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Google Calendar event.

        Use when the user asks to delete, remove, or cancel a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to delete.
            delete_from_kb: Whether to also remove the event from the knowledge base.
                Default is False.
                Set to True to remove from both Google Calendar and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the event name or check if it has been indexed.

        Examples:
        - "Delete the team standup event"
        - "Cancel my dentist appointment on Friday"
        """
        logger.info(
            f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_deletion_context(
                search_space_id, user_id, event_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Event not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch deletion context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Google Calendar account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            event = context["event"]
            event_id = event["event_id"]
            document_id = event.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not event_id:
                return {
                    "status": "error",
                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                }

            logger.info(
                f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
            )
            approval = interrupt(
                {
                    "type": "google_calendar_event_deletion",
                    "action": {
                        "tool": "delete_calendar_event",
                        "params": {
                            "event_id": event_id,
                            "connector_id": connector_id_from_context,
                            "delete_from_kb": delete_from_kb,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
                }

            edited_action = decision.get("edited_action")
            final_params: dict[str, Any] = {}
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_event_id = final_params.get("event_id", event_id)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this event.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_calendar_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                }

            actual_connector_id = connector.id

            logger.info(
                f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .delete(calendarId="primary", eventId=final_event_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Calendar event deleted: event_id={final_event_id}")

            delete_result: dict[str, Any] = {
                "status": "success",
                "event_id": final_event_id,
                "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    delete_result["warning"] = (
                        f"Event deleted, but failed to remove from knowledge base: {e!s}"
                    )

            delete_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                delete_result["message"] = (
                    f"{delete_result.get('message', '')} (also removed from knowledge base)"
                )

            return delete_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the event. Please try again.",
            }

    return delete_calendar_event
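Note (not part of the diff): each tool persists auth_expired by assigning a new dict and calling flag_modified, because SQLAlchemy does not reliably detect in-place mutation of a plain JSON column. A minimal sketch of the pattern, assuming a connector row already loaded in an async session:

    from sqlalchemy.orm.attributes import flag_modified

    connector.config = {**connector.config, "auth_expired": True}  # new dict, not mutation
    flag_modified(connector, "config")  # mark the attribute dirty so the UPDATE is emitted
    await db_session.commit()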
@@ -0,0 +1,382 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def create_update_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def update_calendar_event(
        event_title_or_id: str,
        new_summary: str | None = None,
        new_start_datetime: str | None = None,
        new_end_datetime: str | None = None,
        new_description: str | None = None,
        new_location: str | None = None,
        new_attendees: list[str] | None = None,
    ) -> dict[str, Any]:
        """Update an existing Google Calendar event.

        Use when the user asks to modify, reschedule, or change a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to update.
            new_summary: New event title (if changing).
            new_start_datetime: New start time in ISO 8601 format (if rescheduling).
            new_end_datetime: New end time in ISO 8601 format (if rescheduling).
            new_description: New event description (if changing).
            new_location: New event location (if changing).
            new_attendees: New list of attendee email addresses (if changing).

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - html_link: URL to open the event (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the event name or check if it has been indexed.

        Examples:
        - "Reschedule the team standup to 3pm"
        - "Change the location of my dentist appointment"
        """
        logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, event_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Event not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch update context: {error_msg}")
                return {"status": "error", "message": error_msg}

            if context.get("auth_expired"):
                logger.warning("Google Calendar account has expired authentication")
                return {
                    "status": "auth_error",
                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            event = context["event"]
            event_id = event["event_id"]
            document_id = event.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not event_id:
                return {
                    "status": "error",
                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                }

            logger.info(
                f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
            )
            approval = interrupt(
                {
                    "type": "google_calendar_event_update",
                    "action": {
                        "tool": "update_calendar_event",
                        "params": {
                            "event_id": event_id,
                            "document_id": document_id,
                            "connector_id": connector_id_from_context,
                            "new_summary": new_summary,
                            "new_start_datetime": new_start_datetime,
                            "new_end_datetime": new_end_datetime,
                            "new_description": new_description,
                            "new_location": new_location,
                            "new_attendees": new_attendees,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
                }

            edited_action = decision.get("edited_action")
            final_params: dict[str, Any] = {}
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_event_id = final_params.get("event_id", event_id)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_new_summary = final_params.get("new_summary", new_summary)
            final_new_start_datetime = final_params.get(
                "new_start_datetime", new_start_datetime
            )
            final_new_end_datetime = final_params.get(
                "new_end_datetime", new_end_datetime
            )
            final_new_description = final_params.get("new_description", new_description)
            final_new_location = final_params.get("new_location", new_location)
            final_new_attendees = final_params.get("new_attendees", new_attendees)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this event.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_calendar_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                }

            actual_connector_id = connector.id

            logger.info(
                f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            update_body: dict[str, Any] = {}
            if final_new_summary is not None:
                update_body["summary"] = final_new_summary
            if final_new_start_datetime is not None:
                tz = (
                    context.get("timezone", "UTC")
                    if isinstance(context, dict)
                    else "UTC"
                )
                update_body["start"] = {
                    "dateTime": final_new_start_datetime,
                    "timeZone": tz,
                }
            if final_new_end_datetime is not None:
                tz = (
                    context.get("timezone", "UTC")
                    if isinstance(context, dict)
                    else "UTC"
                )
                update_body["end"] = {
                    "dateTime": final_new_end_datetime,
                    "timeZone": tz,
                }
            if final_new_description is not None:
                update_body["description"] = final_new_description
            if final_new_location is not None:
                update_body["location"] = final_new_location
            if final_new_attendees is not None:
                update_body["attendees"] = [
                    {"email": e.strip()} for e in final_new_attendees if e.strip()
                ]

            if not update_body:
                return {
                    "status": "error",
                    "message": "No changes specified. Please provide at least one field to update.",
                }

            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .patch(
                            calendarId="primary",
                            eventId=final_event_id,
                            body=update_body,
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Calendar event updated: event_id={final_event_id}")

            kb_message_suffix = ""
            if document_id is not None:
                try:
                    from app.services.google_calendar import GoogleCalendarKBSyncService

                    kb_service = GoogleCalendarKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=document_id,
                        event_id=final_event_id,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after update failed: {kb_err}")
                    kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."

            return {
                "status": "success",
                "event_id": final_event_id,
                "html_link": updated.get("htmlLink"),
                "message": f"Successfully updated the calendar event.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the event. Please try again.",
            }

    return update_calendar_event
|
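Every HITL tool in this PR follows the same factory pattern: a closure captures the DB session and identity, interrupt() pauses the graph for human approval, and the resumed decision is parsed before any API call. A minimal sketch of wiring one such tool into a checkpointed LangGraph agent and resuming past the interrupt; the factory name is assumed by symmetry with the Jira factories below, and the model object and thread id are illustrative, not from this PR:

from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command


async def demo(llm, db_session):
    # Bind the closure exactly as the factory above does (factory name assumed).
    tool = create_update_calendar_event_tool(
        db_session=db_session, search_space_id=1, user_id="user-1"
    )
    agent = create_react_agent(llm, tools=[tool], checkpointer=MemorySaver())
    config = {"configurable": {"thread_id": "demo-thread"}}

    # First invocation runs until interrupt() pauses for approval.
    await agent.ainvoke(
        {"messages": [("user", "Move my standup to 10am tomorrow")]}, config
    )
    # Resuming feeds the decision payload the tool parses above; any
    # decision type other than "reject" lets the update proceed.
    return await agent.ainvoke(
        Command(resume={"decisions": [{"type": "approve"}]}), config
    )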
@@ -32,13 +32,16 @@ def create_create_google_drive_file_tool(
        """Create a new Google Doc or Google Sheet in Google Drive.

        Use this tool when the user explicitly asks to create a new document
        or spreadsheet in Google Drive.
        or spreadsheet in Google Drive. The user MUST specify a topic before
        you call this tool. If the request does not contain a topic (e.g.
        "create a drive doc" or "make a Google Sheet"), ask what the file
        should be about. Never call this tool without a clear topic from the user.

        Args:
            name: The file name (without extension).
            file_type: Either "google_doc" or "google_sheet".
            content: Optional initial content. For google_doc, provide markdown text.
                For google_sheet, provide CSV-formatted text.
            content: Optional initial content. Generate from the user's topic.
                For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.

        Returns:
            Dictionary with:

@@ -55,8 +58,8 @@ def create_create_google_drive_file_tool(
        Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Create a Google Doc called 'Meeting Notes'"
        - "Create a spreadsheet named 'Budget 2026' with some sample data"
        - "Create a Google Doc with today's meeting notes"
        - "Create a spreadsheet for the 2026 budget"
        """
        logger.info(
            f"create_google_drive_file called: name='{name}', type='{file_type}'"

@@ -84,6 +87,15 @@ def create_create_google_drive_file_tool(
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Google Drive accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_drive",
                }

            logger.info(
                f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
            )

@@ -154,14 +166,18 @@ def create_create_google_drive_file_tool(
            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _drive_types = [
                SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = result.scalars().first()

@@ -176,8 +192,7 @@ def create_create_google_drive_file_tool(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = result.scalars().first()

@@ -191,8 +206,22 @@ def create_create_google_drive_file_tool(
            logger.info(
                f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
            )

            pre_built_creds = None
            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    pre_built_creds = build_composio_credentials(cca_id)

            client = GoogleDriveClient(
                session=db_session, connector_id=actual_connector_id
                session=db_session,
                connector_id=actual_connector_id,
                credentials=pre_built_creds,
            )
            try:
                created = await client.create_file(

@@ -206,22 +235,65 @@ def create_create_google_drive_file_tool(
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate.",
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.google_drive import GoogleDriveKBSyncService

                kb_service = GoogleDriveKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    file_id=created.get("id"),
                    file_name=created.get("name", final_name),
                    mime_type=mime_type,
                    web_view_link=created.get("webViewLink"),
                    content=final_content,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "file_id": created.get("id"),
                "name": created.get("name"),
                "web_view_link": created.get("webViewLink"),
                "message": f"Successfully created '{created.get('name')}' in Google Drive.",
                "message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
            }

        except Exception as e:
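The try/except block that flips `auth_expired` on a 403 is repeated verbatim across the Calendar, Drive, Jira, and Linear tools in this PR. A hedged sketch of the shared helper it could be factored into; the helper name is invented here, not part of the PR:

import logging

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.db import SearchSourceConnector

logger = logging.getLogger(__name__)


async def mark_connector_auth_expired(
    db_session: AsyncSession, connector_id: int
) -> None:
    """Best-effort flag so the UI can prompt for re-authentication."""
    try:
        res = await db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        conn = res.scalar_one_or_none()
        if conn and not conn.config.get("auth_expired"):
            conn.config = {**conn.config, "auth_expired": True}
            flag_modified(conn, "config")  # JSON column: force the UPDATE
            await db_session.commit()
    except Exception:
        logger.warning(
            "Failed to persist auth_expired for connector %s",
            connector_id,
            exc_info=True,
        )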
@@ -47,7 +47,6 @@ def create_delete_google_drive_file_tool(
          to verify the file name or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
        - "Delete the 'Meeting Notes' file from Google Drive"
        - "Trash the 'Old Budget' spreadsheet"

@@ -76,6 +75,18 @@ def create_delete_google_drive_file_tool(
                logger.error(f"Failed to fetch trash context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Google Drive account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_drive",
                }

            file = context["file"]
            file_id = file["file_id"]
            document_id = file.get("document_id")

@@ -151,13 +162,17 @@ def create_delete_google_drive_file_tool(
            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _drive_types = [
                SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnector.connector_type.in_(_drive_types),
                )
            )
            connector = result.scalars().first()

@@ -170,7 +185,23 @@ def create_delete_google_drive_file_tool(
            logger.info(
                f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
            )
            client = GoogleDriveClient(session=db_session, connector_id=connector.id)

            pre_built_creds = None
            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    pre_built_creds = build_composio_credentials(cca_id)

            client = GoogleDriveClient(
                session=db_session,
                connector_id=connector.id,
                credentials=pre_built_creds,
            )
            try:
                await client.trash_file(file_id=final_file_id)
            except HttpError as http_err:

@@ -178,10 +209,26 @@ def create_delete_google_drive_file_tool(
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {http_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate.",
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise
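Both Drive tools now repeat the same Composio branch before constructing GoogleDriveClient. The same logic distilled into one hypothetical helper, assuming, as the code above implies, that passing credentials=None lets the client fall back to the connector's own OAuth config:

from app.db import SearchSourceConnectorType
from app.utils.google_credentials import build_composio_credentials


def resolve_drive_credentials(connector):
    """Return pre-built Composio credentials, or None so GoogleDriveClient
    loads native OAuth credentials from connector.config (assumed fallback)."""
    if (
        connector.connector_type
        == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
    ):
        cca_id = connector.config.get("composio_connected_account_id")
        if cca_id:
            return build_composio_credentials(cca_id)
    return None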
11  surfsense_backend/app/agents/new_chat/tools/jira/__init__.py  Normal file

@@ -0,0 +1,11 @@
"""Jira tools for creating, updating, and deleting issues."""

from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool

__all__ = [
    "create_create_jira_issue_tool",
    "create_delete_jira_issue_tool",
    "create_update_jira_issue_tool",
]
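The package exposes only the three factories, so call sites bind them per request. A short sketch of how the agent layer might instantiate them; the surrounding function and variable names are illustrative, not from this PR:

from app.agents.new_chat.tools.jira import (
    create_create_jira_issue_tool,
    create_delete_jira_issue_tool,
    create_update_jira_issue_tool,
)


def build_jira_tools(db_session, search_space_id: int, user_id: str) -> list:
    # One closure-bound tool per operation; connector_id stays None so each
    # tool falls back to the first Jira connector in the search space.
    factories = (
        create_create_jira_issue_tool,
        create_delete_jira_issue_tool,
        create_update_jira_issue_tool,
    )
    return [
        f(db_session=db_session, search_space_id=search_space_id, user_id=user_id)
        for f in factories
    ]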
242  surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py  Normal file

@@ -0,0 +1,242 @@
import asyncio
import logging
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

logger = logging.getLogger(__name__)


def create_create_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def create_jira_issue(
        project_key: str,
        summary: str,
        issue_type: str = "Task",
        description: str | None = None,
        priority: str | None = None,
    ) -> dict[str, Any]:
        """Create a new issue in Jira.

        Use this tool when the user explicitly asks to create a new Jira issue/ticket.

        Args:
            project_key: The Jira project key (e.g. "PROJ", "ENG").
            summary: Short, descriptive issue title.
            issue_type: Issue type (default "Task"). Others: "Bug", "Story", "Epic".
            description: Optional description body for the issue.
            priority: Optional priority name (e.g. "High", "Medium", "Low").

        Returns:
            Dictionary with status, issue_key, and message.

        IMPORTANT:
        - If status is "rejected", the user declined. Do NOT retry.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            metadata_service = JiraToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                return {
                    "status": "auth_error",
                    "message": "All connected Jira accounts need re-authentication.",
                    "connector_type": "jira",
                }

            approval = interrupt(
                {
                    "type": "jira_issue_creation",
                    "action": {
                        "tool": "create_jira_issue",
                        "params": {
                            "project_key": project_key,
                            "summary": summary,
                            "issue_type": issue_type,
                            "description": description,
                            "priority": priority,
                            "connector_id": connector_id,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The issue was not created.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_project_key = final_params.get("project_key", project_key)
            final_summary = final_params.get("summary", summary)
            final_issue_type = final_params.get("issue_type", issue_type)
            final_description = final_params.get("description", description)
            final_priority = final_params.get("priority", priority)
            final_connector_id = final_params.get("connector_id", connector_id)

            if not final_summary or not final_summary.strip():
                return {"status": "error", "message": "Issue summary cannot be empty."}
            if not final_project_key:
                return {"status": "error", "message": "A project must be selected."}

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            actual_connector_id = final_connector_id
            if actual_connector_id is None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {"status": "error", "message": "No Jira connector found."}
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == actual_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Jira connector is invalid.",
                    }

            try:
                jira_history = JiraHistoryConnector(
                    session=db_session, connector_id=actual_connector_id
                )
                jira_client = await jira_history._get_jira_client()
                api_result = await asyncio.to_thread(
                    jira_client.create_issue,
                    project_key=final_project_key,
                    summary=final_summary,
                    issue_type=final_issue_type,
                    description=final_description,
                    priority=final_priority,
                )
            except Exception as api_err:
                if "status code 403" in str(api_err).lower():
                    try:
                        _conn = connector
                        _conn.config = {**_conn.config, "auth_expired": True}
                        flag_modified(_conn, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            issue_key = api_result.get("key", "")
            issue_url = (
                f"{jira_history._base_url}/browse/{issue_key}"
                if jira_history._base_url and issue_key
                else ""
            )

            kb_message_suffix = ""
            try:
                from app.services.jira import JiraKBSyncService

                kb_service = JiraKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    issue_id=issue_key,
                    issue_identifier=issue_key,
                    issue_title=final_summary,
                    description=final_description,
                    state="To Do",
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "issue_key": issue_key,
                "issue_url": issue_url,
                "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error creating Jira issue: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the issue.",
            }

    return create_jira_issue
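The decision-parsing logic above accepts either a plain decision or one carrying edited arguments. A hypothetical resume payload matching the shapes the code reads; the field names mirror the parsing above, the values are illustrative:

resume_payload = {
    "decisions": [
        {
            "type": "edit",  # anything other than "reject" proceeds
            "edited_action": {
                "args": {
                    "project_key": "ENG",
                    "summary": "Fix login redirect loop",
                    "issue_type": "Bug",
                    "priority": "High",
                }
            },
        }
    ]
}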
209  surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py  Normal file

@@ -0,0 +1,209 @@
import asyncio
import logging
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def delete_jira_issue(
        issue_title_or_key: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Jira issue.

        Use this tool when the user asks to delete or remove a Jira issue.

        Args:
            issue_title_or_key: The issue key (e.g. "PROJ-42") or title.
            delete_from_kb: Whether to also remove from the knowledge base.

        Returns:
            Dictionary with status, message, and deleted_from_kb.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message to the user.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            metadata_service = JiraToolMetadataService(db_session)
            context = await metadata_service.get_deletion_context(
                search_space_id, user_id, issue_title_or_key
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "jira",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            issue_data = context["issue"]
            issue_key = issue_data["issue_id"]
            document_id = issue_data["document_id"]
            connector_id_from_context = context.get("account", {}).get("id")

            approval = interrupt(
                {
                    "type": "jira_issue_deletion",
                    "action": {
                        "tool": "delete_jira_issue",
                        "params": {
                            "issue_key": issue_key,
                            "connector_id": connector_id_from_context,
                            "delete_from_kb": delete_from_kb,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The issue was not deleted.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_issue_key = final_params.get("issue_key", issue_key)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this issue.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.JIRA_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Jira connector is invalid.",
                }

            try:
                jira_history = JiraHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                jira_client = await jira_history._get_jira_client()
                await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
            except Exception as api_err:
                if "status code 403" in str(api_err).lower():
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()

            message = f"Jira issue {final_issue_key} deleted successfully."
            if deleted_from_kb:
                message += " Also removed from the knowledge base."

            return {
                "status": "success",
                "issue_key": final_issue_key,
                "deleted_from_kb": deleted_from_kb,
                "message": message,
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting Jira issue: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the issue.",
            }

    return delete_jira_issue
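All three Jira tools share the same return contract, so the calling agent can dispatch on status alone. A hypothetical consumer, with the status strings taken straight from the code above and the handler itself invented for illustration:

def summarize_result(result: dict) -> str:
    match result.get("status"):
        case "success":
            return result["message"]
        case "rejected" | "not_found":
            return result["message"]  # relay verbatim, never retry
        case "insufficient_permissions" | "auth_error":
            return "Please re-authenticate the Jira connector in settings."
        case _:
            return result.get("message", "Something went wrong.")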
252  surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py  Normal file

@@ -0,0 +1,252 @@
import asyncio
import logging
from typing import Any

from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

logger = logging.getLogger(__name__)


def create_update_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def update_jira_issue(
        issue_title_or_key: str,
        new_summary: str | None = None,
        new_description: str | None = None,
        new_priority: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Jira issue.

        Use this tool when the user asks to modify, edit, or update a Jira issue.

        Args:
            issue_title_or_key: The issue key (e.g. "PROJ-42") or title to identify the issue.
            new_summary: Optional new title/summary for the issue.
            new_description: Optional new description.
            new_priority: Optional new priority name.

        Returns:
            Dictionary with status and message.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message and ask user to verify.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            metadata_service = JiraToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, issue_title_or_key
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "jira",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            issue_data = context["issue"]
            issue_key = issue_data["issue_id"]
            document_id = issue_data.get("document_id")
            connector_id_from_context = context.get("account", {}).get("id")

            approval = interrupt(
                {
                    "type": "jira_issue_update",
                    "action": {
                        "tool": "update_jira_issue",
                        "params": {
                            "issue_key": issue_key,
                            "document_id": document_id,
                            "new_summary": new_summary,
                            "new_description": new_description,
                            "new_priority": new_priority,
                            "connector_id": connector_id_from_context,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The issue was not updated.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_issue_key = final_params.get("issue_key", issue_key)
            final_summary = final_params.get("new_summary", new_summary)
            final_description = final_params.get("new_description", new_description)
            final_priority = final_params.get("new_priority", new_priority)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_document_id = final_params.get("document_id", document_id)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this issue.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.JIRA_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Jira connector is invalid.",
                }

            fields: dict[str, Any] = {}
            if final_summary:
                fields["summary"] = final_summary
            if final_description is not None:
                fields["description"] = {
                    "type": "doc",
                    "version": 1,
                    "content": [
                        {
                            "type": "paragraph",
                            "content": [{"type": "text", "text": final_description}],
                        }
                    ],
                }
            if final_priority:
                fields["priority"] = {"name": final_priority}

            if not fields:
                return {"status": "error", "message": "No changes specified."}

            try:
                jira_history = JiraHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                jira_client = await jira_history._get_jira_client()
                await asyncio.to_thread(
                    jira_client.update_issue, final_issue_key, fields
                )
            except Exception as api_err:
                if "status code 403" in str(api_err).lower():
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            issue_url = (
                f"{jira_history._base_url}/browse/{final_issue_key}"
                if jira_history._base_url and final_issue_key
                else ""
            )

            kb_message_suffix = ""
            if final_document_id:
                try:
                    from app.services.jira import JiraKBSyncService

                    kb_service = JiraKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=final_document_id,
                        issue_id=final_issue_key,
                        user_id=user_id,
                        search_space_id=search_space_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = (
                            " The knowledge base will be updated in the next sync."
                        )
                except Exception as kb_err:
                    logger.warning(f"KB sync after update failed: {kb_err}")
                    kb_message_suffix = (
                        " The knowledge base will be updated in the next sync."
                    )

            return {
                "status": "success",
                "issue_key": final_issue_key,
                "issue_url": issue_url,
                "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error updating Jira issue: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the issue.",
            }

    return update_jira_issue
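The inline fields["description"] literal above builds a minimal Atlassian Document Format body. The same shape, factored into a helper for readability; the helper name is invented, the structure is exactly what the tool constructs:

def to_adf(text: str) -> dict:
    """Wrap plain text in the minimal ADF document the update call sends."""
    return {
        "type": "doc",
        "version": 1,
        "content": [
            {"type": "paragraph", "content": [{"type": "text", "text": text}]}
        ],
    }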
@@ -9,6 +9,7 @@ This module provides:
"""

import asyncio
import contextlib
import json
import re
import time

@@ -19,7 +20,7 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import shielded_async_session
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger

@@ -60,7 +61,7 @@ def _is_degenerate_query(query: str) -> bool:
async def _browse_recent_documents(
    search_space_id: int,
    document_type: str | None,
    document_type: str | list[str] | None,
    top_k: int,
    start_date: datetime | None,
    end_date: datetime | None,

@@ -83,14 +84,22 @@ async def _browse_recent_documents(
    base_conditions = [Document.search_space_id == search_space_id]

    if document_type is not None:
        if isinstance(document_type, str):
            try:
                doc_type_enum = DocumentType[document_type]
                base_conditions.append(Document.document_type == doc_type_enum)
            except KeyError:
                return []
        type_list = (
            document_type if isinstance(document_type, list) else [document_type]
        )
        doc_type_enums = []
        for dt in type_list:
            if isinstance(dt, str):
                with contextlib.suppress(KeyError):
                    doc_type_enums.append(DocumentType[dt])
            else:
                doc_type_enums.append(dt)
        if not doc_type_enums:
            return []
        if len(doc_type_enums) == 1:
            base_conditions.append(Document.document_type == doc_type_enums[0])
        else:
            base_conditions.append(Document.document_type == document_type)
            base_conditions.append(Document.document_type.in_(doc_type_enums))

    if start_date is not None:
        base_conditions.append(Document.updated_at >= start_date)

@@ -195,10 +204,6 @@ _ALL_CONNECTORS: list[str] = [
    "CRAWLED_URL",
    "CIRCLEBACK",
    "OBSIDIAN_CONNECTOR",
    # Composio connectors
    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
    "COMPOSIO_GMAIL_CONNECTOR",
    "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
]

# Human-readable descriptions for each connector type

@@ -228,10 +233,6 @@ CONNECTOR_DESCRIPTIONS: dict[str, str] = {
    "BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
    "CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
    "OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
    # Composio connectors
    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "Google Drive files via Composio (personal cloud storage)",
    "COMPOSIO_GMAIL_CONNECTOR": "Gmail emails via Composio (personal emails)",
    "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events via Composio (personal calendar)",
}

@@ -352,6 +353,20 @@ def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
    return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))


_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
    {
        "message_id",
        "thread_id",
        "event_id",
        "calendar_id",
        "google_drive_file_id",
        "page_id",
        "issue_id",
        "connector_id",
    }
)


def format_documents_for_context(
    documents: list[dict[str, Any]],
    *,
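A tiny illustration of the filtering the next hunk applies with _INTERNAL_METADATA_KEYS; the sample metadata values are made up, the dict comprehension matches the one introduced below:

metadata = {
    "title": "Q3 planning notes",
    "author": "dana@example.com",
    "google_drive_file_id": "1AbC...",
    "connector_id": 7,
}
metadata_clean = {
    k: v for k, v in metadata.items() if k not in _INTERNAL_METADATA_KEYS
}
assert "google_drive_file_id" not in metadata_clean  # internal IDs never reach the LLM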
@@ -480,7 +495,10 @@ def format_documents_for_context(
    total_docs = len(grouped)

    for doc_idx, g in enumerate(grouped.values()):
        metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
        metadata_clean = {
            k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
        }
        metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
        is_live_search = g["document_type"] in live_search_connectors

        doc_lines: list[str] = [

@@ -617,7 +635,12 @@ async def search_knowledge_base_async(
    if available_document_types:
        doc_types_set = set(available_document_types)
        before_count = len(connectors)
        connectors = [c for c in connectors if c in doc_types_set]
        connectors = [
            c
            for c in connectors
            if c in doc_types_set
            or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
        ]
        skipped = before_count - len(connectors)
        if skipped:
            perf.info(

@@ -654,6 +677,13 @@ async def search_knowledge_base_async(
        )
        browse_connectors = connectors if connectors else [None]  # type: ignore[list-item]

        expanded_browse = []
        for c in browse_connectors:
            if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
                expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
            else:
                expanded_browse.append(c)

        browse_results = await asyncio.gather(
            *[
                _browse_recent_documents(

@@ -663,7 +693,7 @@ async def search_knowledge_base_async(
                    start_date=resolved_start_date,
                    end_date=resolved_end_date,
                )
                for c in browse_connectors
                for c in expanded_browse
            ]
        )
        for docs in browse_results:

@@ -779,6 +809,10 @@ async def search_knowledge_base_async(
            deduplicated.append(doc)

        # Sort by RRF score so the most relevant documents from ANY connector
        # appear first, preventing budget truncation from hiding top results.
        deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)

        output_budget = _compute_tool_output_budget(max_input_tokens)
        result = format_documents_for_context(deduplicated, max_chars=output_budget)
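The browse-expansion loop above pairs each native Composio type with its legacy doctype so documents indexed under either name still match. The real NATIVE_TO_LEGACY_DOCTYPE lives in app.db and its exact entries are not shown in this diff; these sample entries are plausible assumptions used only to make the mechanics concrete:

NATIVE_TO_LEGACY_DOCTYPE = {
    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_CONNECTOR",
    "COMPOSIO_GMAIL_CONNECTOR": "GMAIL_CONNECTOR",
}


def expand_browse(browse_connectors):
    """Mirror of the loop above: a native type browses both itself and its
    legacy doctype; a list element later becomes an IN (...) filter."""
    expanded = []
    for c in browse_connectors:
        if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
            expanded.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
        else:
            expanded.append(c)
    return expanded


assert expand_browse(["COMPOSIO_GMAIL_CONNECTOR", None]) == [
    ["COMPOSIO_GMAIL_CONNECTOR", "GMAIL_CONNECTOR"],
    None,
]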
@@ -38,11 +38,13 @@ def create_create_linear_issue_tool(
        """Create a new issue in Linear.

        Use this tool when the user explicitly asks to create, add, or file
        a new issue / ticket / task in Linear.
        a new issue / ticket / task in Linear. The user MUST describe the issue
        before you call this tool. If the request is vague, ask what the issue
        should be about. Never call this tool without a clear topic from the user.

        Args:
            title: Short, descriptive issue title.
            description: Optional markdown body for the issue.
            title: Short, descriptive issue title. Infer from the user's request.
            description: Optional markdown body for the issue. Generate from context.

        Returns:
            Dictionary with:

@@ -57,9 +59,9 @@ def create_create_linear_issue_tool(
        and move on. Do NOT retry, troubleshoot, or suggest alternatives.

        Examples:
        - "Create a Linear issue titled 'Fix login bug'"
        - "Add a ticket for the payment timeout problem"
        - "File an issue about the broken search feature"
        - "Create a Linear issue for the login bug"
        - "File a ticket about the payment timeout problem"
        - "Add an issue for the broken search feature"
        """
        logger.info(f"create_linear_issue called: title='{title}'")

@@ -82,6 +84,15 @@ def create_create_linear_issue_tool(
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            workspaces = context.get("workspaces", [])
            if workspaces and all(w.get("auth_expired") for w in workspaces):
                logger.warning("All Linear accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "linear",
                }

            logger.info(f"Requesting approval for creating Linear issue: '{title}'")
            approval = interrupt(
                {

@@ -215,12 +226,36 @@ def create_create_linear_issue_tool(
            logger.info(
                f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.linear import LinearKBSyncService

                kb_service = LinearKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    issue_id=result.get("id"),
                    issue_identifier=result.get("identifier", ""),
                    issue_title=result.get("title", final_title),
                    issue_url=result.get("url"),
                    description=final_description,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "issue_id": result.get("id"),
                "identifier": result.get("identifier"),
                "url": result.get("url"),
                "message": result.get("message"),
                "message": (result.get("message", "") + kb_message_suffix),
            }

        except Exception as e:
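The KB-sync message suffix is constructed identically after every create in this PR (Drive, Jira, Linear). A hedged sketch of the small helper it could share; the function name is invented, the strings are the ones the tools use:

def kb_sync_suffix(kb_status: str, noun: str = "issue") -> str:
    """Factor out the user-facing suffix repeated after each KB sync call."""
    if kb_status == "success":
        return " Your knowledge base has also been updated."
    return (
        f" This {noun} will be added to your knowledge base "
        "in the next scheduled sync."
    )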
@@ -64,7 +64,6 @@ def create_delete_linear_issue_tool(
        - If status is "not_found", inform the user conversationally using the exact message
          provided. Do NOT treat this as an error. Simply relay the message and ask the user
          to verify the issue title or identifier, or check if it has been indexed.

        Examples:
        - "Delete the 'Fix login bug' Linear issue"
        - "Archive ENG-42"

@@ -91,6 +90,14 @@ def create_delete_linear_issue_tool(
            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    logger.warning(f"Auth expired for delete context: {error_msg}")
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "linear",
                    }
                if "not found" in error_msg.lower():
                    logger.warning(f"Issue not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}

@@ -103,6 +103,14 @@ def create_update_linear_issue_tool(
            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    logger.warning(f"Auth expired for update context: {error_msg}")
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "linear",
                    }
                if "not found" in error_msg.lower():
                    logger.warning(f"Issue not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
@ -1,465 +0,0 @@
|
|||
"""
|
||||
Link preview tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool for fetching URL metadata (title, description,
|
||||
Open Graph image, etc.) to display rich link previews in the chat UI.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import trafilatura
|
||||
from fake_useragent import UserAgent
|
||||
from langchain_core.tools import tool
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_domain(url: str) -> str:
|
||||
"""Extract the domain from a URL."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc
|
||||
# Remove 'www.' prefix if present
|
||||
if domain.startswith("www."):
|
||||
domain = domain[4:]
|
||||
return domain
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def extract_og_content(html: str, property_name: str) -> str | None:
|
||||
"""Extract Open Graph meta content from HTML."""
|
||||
# Try og:property first
|
||||
pattern = rf'<meta[^>]+property=["\']og:{property_name}["\'][^>]+content=["\']([^"\']+)["\']'
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try content before property
|
||||
pattern = rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+property=["\']og:{property_name}["\']'
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_twitter_content(html: str, name: str) -> str | None:
|
||||
"""Extract Twitter Card meta content from HTML."""
|
||||
pattern = (
|
||||
rf'<meta[^>]+name=["\']twitter:{name}["\'][^>]+content=["\']([^"\']+)["\']'
|
||||
)
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try content before name
|
||||
pattern = (
|
||||
rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']twitter:{name}["\']'
|
||||
)
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_meta_description(html: str) -> str | None:
|
||||
"""Extract meta description from HTML."""
|
||||
pattern = r'<meta[^>]+name=["\']description["\'][^>]+content=["\']([^"\']+)["\']'
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try content before name
|
||||
pattern = r'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']description["\']'
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_title(html: str) -> str | None:
|
||||
"""Extract title from HTML."""
|
||||
# Try og:title first
|
||||
og_title = extract_og_content(html, "title")
|
||||
if og_title:
|
||||
return og_title
|
||||
|
||||
# Try twitter:title
|
||||
twitter_title = extract_twitter_content(html, "title")
|
||||
if twitter_title:
|
||||
return twitter_title
|
||||
|
||||
# Fall back to <title> tag
|
||||
pattern = r"<title[^>]*>([^<]+)</title>"
|
||||
match = re.search(pattern, html, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_description(html: str) -> str | None:
|
||||
"""Extract description from HTML."""
|
||||
# Try og:description first
|
||||
og_desc = extract_og_content(html, "description")
|
||||
if og_desc:
|
||||
return og_desc
|
||||
|
||||
# Try twitter:description
|
||||
twitter_desc = extract_twitter_content(html, "description")
|
||||
if twitter_desc:
|
||||
return twitter_desc
|
||||
|
||||
# Fall back to meta description
|
||||
return extract_meta_description(html)
|
||||
|
||||
|
||||
def extract_image(html: str) -> str | None:
|
||||
"""Extract image URL from HTML."""
|
||||
# Try og:image first
|
||||
og_image = extract_og_content(html, "image")
|
||||
if og_image:
|
||||
return og_image
|
||||
|
||||
# Try twitter:image
|
||||
twitter_image = extract_twitter_content(html, "image")
|
||||
if twitter_image:
|
||||
return twitter_image
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def generate_preview_id(url: str) -> str:
|
||||
"""Generate a unique ID for a link preview."""
|
||||
hash_val = hashlib.md5(url.encode()).hexdigest()[:12]
|
||||
return f"link-preview-{hash_val}"
|
||||
|
||||
|
||||
def _unescape_html(text: str) -> str:
|
||||
"""Unescape common HTML entities."""
|
||||
return (
|
||||
text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace(""", '"')
|
||||
.replace("'", "'")
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
|
||||
def _make_absolute_url(image_url: str, base_url: str) -> str:
|
||||
"""Convert a relative image URL to an absolute URL."""
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
return image_url
|
||||
if image_url.startswith("//"):
|
||||
return f"https:{image_url}"
|
||||
if image_url.startswith("/"):
|
||||
parsed = urlparse(base_url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}{image_url}"
|
||||
return image_url
|
||||


async def fetch_with_chromium(url: str) -> dict[str, Any] | None:
    """
    Fetch page content using a headless Chromium browser via Playwright.
    Used as a fallback when simple HTTP requests are blocked (403, etc.).

    Runs the sync Playwright API in a thread so it works on any event
    loop, including Windows ``SelectorEventLoop``.

    Args:
        url: URL to fetch

    Returns:
        Dict with title, description, image, and raw_html, or None if failed
    """
    try:
        return await asyncio.to_thread(_fetch_with_chromium_sync, url)
    except Exception as e:
        logger.error(f"[link_preview] Chromium fallback failed for {url}: {e}")
        return None


def _fetch_with_chromium_sync(url: str) -> dict[str, Any] | None:
    """Synchronous Playwright fetch executed in a worker thread."""
    logger.info(f"[link_preview] Falling back to Chromium for {url}")

    ua = UserAgent()
    user_agent = ua.random

    playwright_proxy = get_playwright_proxy()

    with sync_playwright() as p:
        launch_kwargs: dict = {"headless": True}
        if playwright_proxy:
            launch_kwargs["proxy"] = playwright_proxy
        browser = p.chromium.launch(**launch_kwargs)
        context = browser.new_context(user_agent=user_agent)
        page = context.new_page()

        try:
            page.goto(url, wait_until="domcontentloaded", timeout=30000)
            raw_html = page.content()
        finally:
            browser.close()

    if not raw_html or len(raw_html.strip()) == 0:
        logger.warning(f"[link_preview] Chromium returned empty content for {url}")
        return None

    trafilatura_metadata = trafilatura.extract_metadata(raw_html)

    image = extract_image(raw_html)

    result: dict[str, Any] = {
        "title": None,
        "description": None,
        "image": image,
        "raw_html": raw_html,
    }

    if trafilatura_metadata:
        result["title"] = trafilatura_metadata.title
        result["description"] = trafilatura_metadata.description

    if not result["title"]:
        result["title"] = extract_title(raw_html)
    if not result["description"]:
        result["description"] = extract_description(raw_html)

    logger.info(f"[link_preview] Successfully fetched {url} via Chromium")
    return result


def create_link_preview_tool():
    """
    Factory function to create the link_preview tool.

    Returns:
        A configured tool function for fetching link previews.
    """

    @tool
    async def link_preview(url: str) -> dict[str, Any]:
        """
        Fetch metadata for a URL to display a rich link preview.

        Use this tool when the user shares a URL or asks about a specific webpage.
        This tool fetches the page's Open Graph metadata (title, description, image)
        to display a nice preview card in the chat.

        Common triggers include:
        - User shares a URL in the chat
        - User asks "What's this link about?" or similar
        - User says "Show me a preview of this page"
        - User wants to preview an article or webpage

        Args:
            url: The URL to fetch metadata for. Must be a valid HTTP/HTTPS URL.

        Returns:
            A dictionary containing:
            - id: Unique identifier for this preview
            - assetId: The URL itself (for deduplication)
            - kind: "link" (type of media card)
            - href: The URL to open when clicked
            - title: Page title
            - description: Page description (if available)
            - thumb: Thumbnail/preview image URL (if available)
            - domain: The domain name
            - error: Error message (if fetch failed)
        """
        preview_id = generate_preview_id(url)
        domain = extract_domain(url)

        # Ensure the URL has an HTTP(S) scheme
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"

        try:
            # Generate a random User-Agent to avoid bot detection
            ua = UserAgent()
            user_agent = ua.random

            # Use residential proxy if configured
            proxy_url = get_residential_proxy_url()

            # Use a browser-like User-Agent to fetch Open Graph metadata.
            # We're only fetching publicly available metadata (title, description, thumbnail)
            # that websites intentionally expose via OG tags for link preview purposes.
            async with httpx.AsyncClient(
                timeout=10.0,
                follow_redirects=True,
                proxy=proxy_url,
                headers={
                    "User-Agent": user_agent,
                    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
                    "Accept-Language": "en-US,en;q=0.9",
                    "Accept-Encoding": "gzip, deflate, br",
                    "Cache-Control": "no-cache",
                    "Pragma": "no-cache",
                },
            ) as client:
                response = await client.get(url)
                response.raise_for_status()

                # Get content type to ensure it's HTML
                content_type = response.headers.get("content-type", "")
                if "text/html" not in content_type.lower():
                    # Not an HTML page, return basic info
                    return {
                        "id": preview_id,
                        "assetId": url,
                        "kind": "link",
                        "href": url,
                        "title": url.split("/")[-1] or domain,
                        "description": f"File from {domain}",
                        "domain": domain,
                    }

                html = response.text

                # Extract metadata
                title = extract_title(html) or domain
                description = extract_description(html)
                image = extract_image(html)

                # Make sure image URL is absolute
                if image:
                    image = _make_absolute_url(image, url)

                # Clean up title and description (unescape HTML entities)
                if title:
                    title = _unescape_html(title)
                if description:
                    description = _unescape_html(description)
                    # Truncate long descriptions
                    if len(description) > 200:
                        description = description[:197] + "..."

                return {
                    "id": preview_id,
                    "assetId": url,
                    "kind": "link",
                    "href": url,
                    "title": title,
                    "description": description,
                    "thumb": image,
                    "domain": domain,
                }

        except httpx.TimeoutException:
            # Timeout - try Chromium fallback
            logger.warning(
                f"[link_preview] Timeout for {url}, trying Chromium fallback"
            )
            chromium_result = await fetch_with_chromium(url)
            if chromium_result:
                title = chromium_result.get("title") or domain
                description = chromium_result.get("description")
                image = chromium_result.get("image")

                # Clean up and truncate
                if title:
                    title = _unescape_html(title)
                if description:
                    description = _unescape_html(description)
                    if len(description) > 200:
                        description = description[:197] + "..."

                # Make sure image URL is absolute
                if image:
                    image = _make_absolute_url(image, url)

                return {
                    "id": preview_id,
                    "assetId": url,
                    "kind": "link",
                    "href": url,
                    "title": title,
                    "description": description,
                    "thumb": image,
                    "domain": domain,
                }

            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": "Request timed out",
            }
        except httpx.HTTPStatusError as e:
            status_code = e.response.status_code

            # For 403 (Forbidden) and similar bot-detection errors, try Chromium fallback
            if status_code in (403, 401, 406, 429):
                logger.warning(
                    f"[link_preview] HTTP {status_code} for {url}, trying Chromium fallback"
                )
                chromium_result = await fetch_with_chromium(url)
                if chromium_result:
                    title = chromium_result.get("title") or domain
                    description = chromium_result.get("description")
                    image = chromium_result.get("image")

                    # Clean up and truncate
                    if title:
                        title = _unescape_html(title)
                    if description:
                        description = _unescape_html(description)
                        if len(description) > 200:
                            description = description[:197] + "..."

                    # Make sure image URL is absolute
                    if image:
                        image = _make_absolute_url(image, url)

                    return {
                        "id": preview_id,
                        "assetId": url,
                        "kind": "link",
                        "href": url,
                        "title": title,
                        "description": description,
                        "thumb": image,
                        "domain": domain,
                    }

            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": f"HTTP {status_code}",
            }
        except Exception as e:
            error_message = str(e)
            logger.error(f"[link_preview] Error fetching {url}: {error_message}")
            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": f"Failed to fetch: {error_message[:50]}",
            }

    return link_preview
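# --- Illustrative usage sketch (not part of the diff) ------------------------
# The factory returns a LangChain tool, so it can be invoked directly in an
# async context; the URL below is a placeholder.
#
#     import asyncio
#
#     link_preview_tool = create_link_preview_tool()
#     card = asyncio.run(link_preview_tool.ainvoke({"url": "https://example.com"}))
#     print(card["title"], card.get("thumb"))
# ------------------------------------------------------------------------------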

@@ -33,17 +33,21 @@ def create_create_notion_page_tool(
    @tool
    async def create_notion_page(
        title: str,
        content: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new page in Notion with the given title and content.

        Use this tool when the user asks you to create, save, or publish
        something to Notion. The page will be created in the user's
        configured Notion workspace.
        configured Notion workspace. The user MUST specify a topic before you
        call this tool. If the request does not contain a topic (e.g. "create a
        notion page"), ask what the page should be about. Never call this tool
        without a clear topic from the user.

        Args:
            title: The title of the Notion page.
            content: The markdown content for the page body (supports headings, lists, paragraphs).
            content: Optional markdown content for the page body (supports headings, lists, paragraphs).
                Generate this yourself based on the user's topic.

        Returns:
            Dictionary with:

@@ -58,8 +62,8 @@ def create_create_notion_page_tool(
        and move on. Do NOT troubleshoot or suggest alternatives.

        Examples:
        - "Create a Notion page titled 'Meeting Notes' with content 'Discussed project timeline'"
        - "Save this to Notion with title 'Research Summary'"
        - "Create a Notion page about our Q2 roadmap"
        - "Save a summary of today's discussion to Notion"
        """
        logger.info(f"create_notion_page called: title='{title}'")

@@ -85,6 +89,15 @@ def create_create_notion_page_tool(
                "message": context["error"],
            }

        accounts = context.get("accounts", [])
        if accounts and all(a.get("auth_expired") for a in accounts):
            logger.warning("All Notion accounts have expired authentication")
            return {
                "status": "auth_error",
                "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
                "connector_type": "notion",
            }

        logger.info(f"Requesting approval for creating Notion page: '{title}'")
        approval = interrupt(
            {

@@ -215,6 +228,34 @@ def create_create_notion_page_tool(
        logger.info(
            f"create_page result: {result.get('status')} - {result.get('message', '')}"
        )

        if result.get("status") == "success":
            kb_message_suffix = ""
            try:
                from app.services.notion import NotionKBSyncService

                kb_service = NotionKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    page_id=result.get("page_id"),
                    page_title=result.get("title", final_title),
                    page_url=result.get("url"),
                    content=final_content,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = (
                        " Your knowledge base has also been updated."
                    )
                else:
                    kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."

            result["message"] = result.get("message", "") + kb_message_suffix

        return result

    except Exception as e:

@@ -95,8 +95,19 @@ def create_delete_notion_page_tool(
                "message": error_msg,
            }

        account = context.get("account", {})
        if account.get("auth_expired"):
            logger.warning(
                "Notion account %s has expired authentication",
                account.get("id"),
            )
            return {
                "status": "auth_error",
                "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
            }

        page_id = context.get("page_id")
        connector_id_from_context = context.get("account", {}).get("id")
        connector_id_from_context = account.get("id")
        document_id = context.get("document_id")

        logger.info(

@@ -262,6 +273,18 @@ def create_delete_notion_page_tool(
            raise

        logger.error(f"Error deleting Notion page: {e}", exc_info=True)
        error_str = str(e).lower()
        if isinstance(e, NotionAPIError) and (
            "401" in error_str or "unauthorized" in error_str
        ):
            return {
                "status": "auth_error",
                "message": str(e),
                "connector_id": connector_id_from_context
                if "connector_id_from_context" in dir()
                else None,
                "connector_type": "notion",
            }
        if isinstance(e, ValueError | NotionAPIError):
            message = str(e)
        else:

@@ -33,16 +33,19 @@ def create_update_notion_page_tool(
    @tool
    async def update_notion_page(
        page_title: str,
        content: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Notion page by appending new content.

        Use this tool when the user asks you to add content to, modify, or update
        a Notion page. The new content will be appended to the existing page content.
        The user MUST specify what to add before you call this tool. If the
        request is vague, ask what content they want added.

        Args:
            page_title: The title of the Notion page to update.
            content: The markdown content to append to the page body (supports headings, lists, paragraphs).
            content: Optional markdown content to append to the page body (supports headings, lists, paragraphs).
                Generate this yourself based on the user's request.

        Returns:
            Dictionary with:

@@ -60,10 +63,9 @@ def create_update_notion_page_tool(
        Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]"
        Do NOT treat this as an error. Do NOT invent information. Simply relay the message and
        ask the user to verify the page title or check if it's been indexed.

        Examples:
        - "Add 'New meeting notes from today' to the 'Meeting Notes' Notion page"
        - "Append the following to the 'Project Plan' Notion page: '# Status Update\n\nCompleted phase 1'"
        - "Add today's meeting notes to the 'Meeting Notes' Notion page"
        - "Update the 'Project Plan' page with a status update on phase 1"
        """
        logger.info(
            f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"

@@ -107,6 +109,17 @@ def create_update_notion_page_tool(
                "message": error_msg,
            }

        account = context.get("account", {})
        if account.get("auth_expired"):
            logger.warning(
                "Notion account %s has expired authentication",
                account.get("id"),
            )
            return {
                "status": "auth_error",
                "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
            }

        page_id = context.get("page_id")
        document_id = context.get("document_id")
        connector_id_from_context = context.get("account", {}).get("id")

@@ -261,6 +274,18 @@ def create_update_notion_page_tool(
            raise

        logger.error(f"Error updating Notion page: {e}", exc_info=True)
        error_str = str(e).lower()
        if isinstance(e, NotionAPIError) and (
            "401" in error_str or "unauthorized" in error_str
        ):
            return {
                "status": "auth_error",
                "message": str(e),
                "connector_id": connector_id_from_context
                if "connector_id_from_context" in dir()
                else None,
                "connector_type": "notion",
            }
        if isinstance(e, ValueError | NotionAPIError):
            message = str(e)
        else:

@@ -45,19 +45,38 @@ from langchain_core.tools import BaseTool

from app.db import ChatVisibility

from .display_image import create_display_image_tool
from .confluence import (
    create_create_confluence_page_tool,
    create_delete_confluence_page_tool,
    create_update_confluence_page_tool,
)
from .generate_image import create_generate_image_tool
from .gmail import (
    create_create_gmail_draft_tool,
    create_send_gmail_email_tool,
    create_trash_gmail_email_tool,
    create_update_gmail_draft_tool,
)
from .google_calendar import (
    create_create_calendar_event_tool,
    create_delete_calendar_event_tool,
    create_update_calendar_event_tool,
)
from .google_drive import (
    create_create_google_drive_file_tool,
    create_delete_google_drive_file_tool,
)
from .jira import (
    create_create_jira_issue_tool,
    create_delete_jira_issue_tool,
    create_update_jira_issue_tool,
)
from .knowledge_base import create_search_knowledge_base_tool
from .linear import (
    create_create_linear_issue_tool,
    create_delete_linear_issue_tool,
    create_update_linear_issue_tool,
)
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
from .notion import (
    create_create_notion_page_tool,

@@ -166,20 +185,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
        # are optional — when missing, source_strategy="kb_search" degrades
        # gracefully to "provided"
    ),
    # Link preview tool - fetches Open Graph metadata for URLs
    ToolDefinition(
        name="link_preview",
        description="Fetch metadata for a URL to display a rich preview card",
        factory=lambda deps: create_link_preview_tool(),
        requires=[],
    ),
    # Display image tool - shows images in the chat
    ToolDefinition(
        name="display_image",
        description="Display an image in the chat with metadata",
        factory=lambda deps: create_display_image_tool(),
        requires=[],
    ),
    # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
    ToolDefinition(
        name="generate_image",

@@ -257,7 +262,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
        requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
    ),
    # =========================================================================
    # LINEAR TOOLS - create, update, delete issues (WIP - hidden from UI)
    # LINEAR TOOLS - create, update, delete issues
    # Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_linear_issue",

@@ -268,8 +274,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    ToolDefinition(
        name="update_linear_issue",

@@ -280,8 +284,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    ToolDefinition(
        name="delete_linear_issue",

@@ -292,11 +294,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    # =========================================================================
    # NOTION TOOLS - create, update, delete pages (WIP - hidden from UI)
    # NOTION TOOLS - create, update, delete pages
    # Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_notion_page",

@@ -307,8 +308,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    ToolDefinition(
        name="update_notion_page",

@@ -319,8 +318,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    ToolDefinition(
        name="delete_notion_page",

@@ -331,11 +328,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    # =========================================================================
    # GOOGLE DRIVE TOOLS - create files, delete files (WIP - hidden from UI)
    # GOOGLE DRIVE TOOLS - create files, delete files
    # Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_google_drive_file",

@@ -346,8 +342,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    ToolDefinition(
        name="delete_google_drive_file",

@@ -358,8 +352,152 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        enabled_by_default=False,
        hidden=True,
    ),
    # =========================================================================
    # GOOGLE CALENDAR TOOLS - create, update, delete events
    # Auto-disabled when no Google Calendar connector is configured
    # =========================================================================
    ToolDefinition(
        name="create_calendar_event",
        description="Create a new event on Google Calendar",
        factory=lambda deps: create_create_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_calendar_event",
        description="Update an existing indexed Google Calendar event",
        factory=lambda deps: create_update_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_calendar_event",
        description="Delete an existing indexed Google Calendar event",
        factory=lambda deps: create_delete_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
    # Auto-disabled when no Gmail connector is configured
    # =========================================================================
    ToolDefinition(
        name="create_gmail_draft",
        description="Create a draft email in Gmail",
        factory=lambda deps: create_create_gmail_draft_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="send_gmail_email",
        description="Send an email via Gmail",
        factory=lambda deps: create_send_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="trash_gmail_email",
        description="Move an indexed email to trash in Gmail",
        factory=lambda deps: create_trash_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_gmail_draft",
        description="Update an existing Gmail draft",
        factory=lambda deps: create_update_gmail_draft_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # JIRA TOOLS - create, update, delete issues
    # Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_jira_issue",
        description="Create a new issue in the user's Jira project",
        factory=lambda deps: create_create_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_jira_issue",
        description="Update an existing indexed Jira issue",
        factory=lambda deps: create_update_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_jira_issue",
        description="Delete an existing indexed Jira issue",
        factory=lambda deps: create_delete_jira_issue_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # CONFLUENCE TOOLS - create, update, delete pages
    # Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_confluence_page",
        description="Create a new page in the user's Confluence space",
        factory=lambda deps: create_create_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="update_confluence_page",
        description="Update an existing indexed Confluence page",
        factory=lambda deps: create_update_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_confluence_page",
        description="Delete an existing indexed Confluence page",
        factory=lambda deps: create_delete_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
]

@@ -413,7 +551,7 @@ def build_tools(
        tools = build_tools(deps)

        # Use only specific tools
        tools = build_tools(deps, enabled_tools=["search_knowledge_base", "link_preview"])
        tools = build_tools(deps, enabled_tools=["search_knowledge_base"])

        # Use defaults but disable podcast
        tools = build_tools(deps, disabled_tools=["generate_podcast"])
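# --- Illustrative sketch (not part of the diff) ------------------------------
# Registering a hypothetical extra tool with the same ToolDefinition shape used
# by BUILTIN_TOOLS above; create_my_tool is a placeholder factory.
#
#     BUILTIN_TOOLS.append(
#         ToolDefinition(
#             name="my_tool",
#             description="Example of the registry entry shape",
#             factory=lambda deps: create_my_tool(
#                 db_session=deps["db_session"],
#                 search_space_id=deps["search_space_id"],
#                 user_id=deps["user_id"],
#             ),
#             requires=["db_session", "search_space_id", "user_id"],
#             enabled_by_default=False,  # opt-in, like the WIP tools above
#         )
#     )
# ------------------------------------------------------------------------------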

@@ -1,719 +0,0 @@
"""
Composio Gmail Connector Module.

Provides Gmail specific methods for data retrieval and indexing via Composio.
"""

import logging
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import Any

from bs4 import BeautifulSoup
from markdownify import markdownify as md
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload

from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
    calculate_date_range,
    check_duplicate_document_by_hash,
    safe_set_chunks,
)
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

# Heartbeat configuration
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
HEARTBEAT_INTERVAL_SECONDS = 30

logger = logging.getLogger(__name__)


def get_current_timestamp() -> datetime:
    """Get the current timestamp with timezone for updated_at field."""
    return datetime.now(UTC)


async def check_document_by_unique_identifier(
    session: AsyncSession, unique_identifier_hash: str
) -> Document | None:
    """Check if a document with the given unique identifier hash already exists."""
    existing_doc_result = await session.execute(
        select(Document)
        .options(selectinload(Document.chunks))
        .where(Document.unique_identifier_hash == unique_identifier_hash)
    )
    return existing_doc_result.scalars().first()


async def update_connector_last_indexed(
    session: AsyncSession,
    connector,
    update_last_indexed: bool = True,
) -> None:
    """Update the last_indexed_at timestamp for a connector."""
    if update_last_indexed:
        connector.last_indexed_at = datetime.now(UTC)
        logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")


class ComposioGmailConnector(ComposioConnector):
    """
    Gmail specific Composio connector.

    Provides methods for listing messages, getting message details, and
    formatting Gmail messages retrieved via Composio.
    """

    async def list_gmail_messages(
        self,
        query: str = "",
        max_results: int = 50,
        page_token: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None, int | None, str | None]:
        """
        List Gmail messages via Composio with pagination support.

        Args:
            query: Gmail search query.
            max_results: Maximum number of messages per page (default: 50).
            page_token: Optional pagination token for next page.

        Returns:
            Tuple of (messages list, next_page_token, result_size_estimate, error message).
        """
        connected_account_id = await self.get_connected_account_id()
        if not connected_account_id:
            return [], None, None, "No connected account ID found"

        entity_id = await self.get_entity_id()
        service = await self._get_service()
        return await service.get_gmail_messages(
            connected_account_id=connected_account_id,
            entity_id=entity_id,
            query=query,
            max_results=max_results,
            page_token=page_token,
        )

    async def get_gmail_message_detail(
        self, message_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        """
        Get full details of a Gmail message via Composio.

        Args:
            message_id: Gmail message ID.

        Returns:
            Tuple of (message details, error message).
        """
        connected_account_id = await self.get_connected_account_id()
        if not connected_account_id:
            return None, "No connected account ID found"

        entity_id = await self.get_entity_id()
        service = await self._get_service()
        return await service.get_gmail_message_detail(
            connected_account_id=connected_account_id,
            entity_id=entity_id,
            message_id=message_id,
        )

    @staticmethod
    def _html_to_markdown(html: str) -> str:
        """Convert HTML (especially email layouts with nested tables) to clean markdown."""
        soup = BeautifulSoup(html, "html.parser")
        for tag in soup.find_all(["style", "script", "img"]):
            tag.decompose()
        for tag in soup.find_all(
            ["table", "thead", "tbody", "tfoot", "tr", "td", "th"]
        ):
            tag.unwrap()
        return md(str(soup)).strip()
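    # --- Illustrative sketch (not part of the diff) ---------------------------
    # What _html_to_markdown does to a typical table-based email layout; the
    # sample HTML is hypothetical and the exact markdown depends on markdownify
    # defaults.
    #
    #     html = "<table><tr><td><h1>Hello</h1><p>World</p></td></tr></table>"
    #     ComposioGmailConnector._html_to_markdown(html)
    #     # -> roughly "Hello\n=====\n\nWorld" (table tags unwrapped first)
    # ---------------------------------------------------------------------------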

    def format_gmail_message_to_markdown(self, message: dict[str, Any]) -> str:
        """
        Format a Gmail message to markdown.

        Args:
            message: Message object from Composio's GMAIL_FETCH_EMAILS response.
                Composio structure: messageId, messageText, messageTimestamp,
                payload.headers, labelIds, attachmentList

        Returns:
            Formatted markdown string.
        """
        try:
            # Composio uses 'messageId' (camelCase)
            message_id = message.get("messageId", "") or message.get("id", "")
            label_ids = message.get("labelIds", [])

            # Extract headers from payload
            payload = message.get("payload", {})
            headers = payload.get("headers", [])

            # Parse headers into a dict
            header_dict = {}
            for header in headers:
                name = header.get("name", "").lower()
                value = header.get("value", "")
                header_dict[name] = value

            # Extract key information
            subject = header_dict.get("subject", "No Subject")
            from_email = header_dict.get("from", "Unknown Sender")
            to_email = header_dict.get("to", "Unknown Recipient")
            # Composio provides messageTimestamp directly
            date_str = message.get("messageTimestamp", "") or header_dict.get(
                "date", "Unknown Date"
            )

            # Build markdown content
            markdown_content = f"# {subject}\n\n"
            markdown_content += f"**From:** {from_email}\n"
            markdown_content += f"**To:** {to_email}\n"
            markdown_content += f"**Date:** {date_str}\n"

            if label_ids:
                markdown_content += f"**Labels:** {', '.join(label_ids)}\n"

            markdown_content += "\n---\n\n"

            # Composio provides full message text in 'messageText' which is often raw HTML
            message_text = message.get("messageText", "")
            if message_text:
                message_text = self._html_to_markdown(message_text)
                markdown_content += f"## Content\n\n{message_text}\n\n"
            else:
                # Fall back to snippet if no messageText
                snippet = message.get("snippet", "")
                if snippet:
                    markdown_content += f"## Preview\n\n{snippet}\n\n"

            # Add attachment info if present
            attachments = message.get("attachmentList", [])
            if attachments:
                markdown_content += "## Attachments\n\n"
                for att in attachments:
                    att_name = att.get("filename", att.get("name", "Unknown"))
                    markdown_content += f"- {att_name}\n"
                markdown_content += "\n"

            # Add message metadata
            markdown_content += "## Message Details\n\n"
            markdown_content += f"- **Message ID:** {message_id}\n"

            return markdown_content

        except Exception as e:
            return f"Error formatting message to markdown: {e!s}"


# ============ Indexer Functions ============


async def _analyze_gmail_messages_phase1(
    session: AsyncSession,
    messages: list[dict[str, Any]],
    composio_connector: ComposioGmailConnector,
    connector_id: int,
    search_space_id: int,
    user_id: str,
) -> tuple[list[dict[str, Any]], int, int]:
    """
    Phase 1: Analyze all messages, create pending documents.
    Makes ALL documents visible in the UI immediately with pending status.

    Returns:
        Tuple of (messages_to_process, documents_skipped, duplicate_content_count)
    """
    messages_to_process = []
    documents_skipped = 0
    duplicate_content_count = 0

    for message in messages:
        try:
            # Composio uses 'messageId' (camelCase), not 'id'
            message_id = message.get("messageId", "") or message.get("id", "")
            if not message_id:
                documents_skipped += 1
                continue

            # Extract message info from Composio response
            payload = message.get("payload", {})
            headers = payload.get("headers", [])

            subject = "No Subject"
            sender = "Unknown Sender"
            date_str = message.get("messageTimestamp", "Unknown Date")

            for header in headers:
                name = header.get("name", "").lower()
                value = header.get("value", "")
                if name == "subject":
                    subject = value
                elif name == "from":
                    sender = value
                elif name == "date":
                    date_str = value

            # Format to markdown using the full message data
            markdown_content = composio_connector.format_gmail_message_to_markdown(
                message
            )

            # Check for empty content
            if not markdown_content.strip():
                logger.warning(f"Skipping Gmail message with no content: {subject}")
                documents_skipped += 1
                continue

            # Generate unique identifier
            document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"])
            unique_identifier_hash = generate_unique_identifier_hash(
                document_type, f"gmail_{message_id}", search_space_id
            )

            content_hash = generate_content_hash(markdown_content, search_space_id)

            existing_document = await check_document_by_unique_identifier(
                session, unique_identifier_hash
            )

            # Get label IDs and thread_id from Composio response
            label_ids = message.get("labelIds", [])
            thread_id = message.get("threadId", "") or message.get("thread_id", "")

            if existing_document:
                if existing_document.content_hash == content_hash:
                    # Ensure status is ready (might have been stuck in processing/pending)
                    if not DocumentStatus.is_state(
                        existing_document.status, DocumentStatus.READY
                    ):
                        existing_document.status = DocumentStatus.ready()
                    documents_skipped += 1
                    continue

                # Queue existing document for update (will be set to processing in Phase 2)
                messages_to_process.append(
                    {
                        "document": existing_document,
                        "is_new": False,
                        "markdown_content": markdown_content,
                        "content_hash": content_hash,
                        "message_id": message_id,
                        "thread_id": thread_id,
                        "subject": subject,
                        "sender": sender,
                        "date_str": date_str,
                        "label_ids": label_ids,
                    }
                )
                continue

            # Document doesn't exist by unique_identifier_hash.
            # Check if a document with the same content_hash exists (from standard connector)
            with session.no_autoflush:
                duplicate_by_content = await check_duplicate_document_by_hash(
                    session, content_hash
                )

            if duplicate_by_content:
                logger.info(
                    f"Message {subject} already indexed by another connector "
                    f"(existing document ID: {duplicate_by_content.id}, "
                    f"type: {duplicate_by_content.document_type}). Skipping."
                )
                duplicate_content_count += 1
                documents_skipped += 1
                continue

            # Create new document with PENDING status (visible in UI immediately)
            document = Document(
                search_space_id=search_space_id,
                title=subject,
                document_type=DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"]),
                document_metadata={
                    "message_id": message_id,
                    "thread_id": thread_id,
                    "subject": subject,
                    "sender": sender,
                    "date": date_str,
                    "labels": label_ids,
                    "connector_id": connector_id,
                    "toolkit_id": "gmail",
                    "source": "composio",
                },
                content="Pending...",  # Placeholder until processed
                content_hash=unique_identifier_hash,  # Temporary unique value - updated when ready
                unique_identifier_hash=unique_identifier_hash,
                embedding=None,
                chunks=[],  # Empty at creation - safe for async
                status=DocumentStatus.pending(),  # Pending until processing starts
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
                connector_id=connector_id,
            )
            session.add(document)

            messages_to_process.append(
                {
                    "document": document,
                    "is_new": True,
                    "markdown_content": markdown_content,
                    "content_hash": content_hash,
                    "message_id": message_id,
                    "thread_id": thread_id,
                    "subject": subject,
                    "sender": sender,
                    "date_str": date_str,
                    "label_ids": label_ids,
                }
            )

        except Exception as e:
            logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
            documents_skipped += 1
            continue

    return messages_to_process, documents_skipped, duplicate_content_count


async def _process_gmail_messages_phase2(
    session: AsyncSession,
    messages_to_process: list[dict[str, Any]],
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool = False,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
    """
    Phase 2: Process each document one by one.
    Each document transitions: pending → processing → ready/failed

    Returns:
        Tuple of (documents_indexed, documents_failed)
    """
    documents_indexed = 0
    documents_failed = 0
    last_heartbeat_time = time.time()

    for item in messages_to_process:
        # Send heartbeat periodically
        if on_heartbeat_callback:
            current_time = time.time()
            if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
                await on_heartbeat_callback(documents_indexed)
                last_heartbeat_time = current_time

        document = item["document"]
        try:
            # Set to PROCESSING and commit - shows "processing" in UI for THIS document only
            document.status = DocumentStatus.processing()
            await session.commit()

            # Heavy processing (LLM, embeddings, chunks)
            user_llm = await get_user_long_context_llm(
                session, user_id, search_space_id
            )

            if user_llm and enable_summary:
                document_metadata_for_summary = {
                    "message_id": item["message_id"],
                    "thread_id": item["thread_id"],
                    "subject": item["subject"],
                    "sender": item["sender"],
                    "document_type": "Gmail Message (Composio)",
                }
                summary_content, summary_embedding = await generate_document_summary(
                    item["markdown_content"], user_llm, document_metadata_for_summary
                )
            else:
                summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(item["markdown_content"])

            # Update document to READY with actual content
            document.title = item["subject"]
            document.content = summary_content
            document.content_hash = item["content_hash"]
            document.embedding = summary_embedding
            document.document_metadata = {
                "message_id": item["message_id"],
                "thread_id": item["thread_id"],
                "subject": item["subject"],
                "sender": item["sender"],
                "date": item["date_str"],
                "labels": item["label_ids"],
                "connector_id": connector_id,
                "source": "composio",
            }
            await safe_set_chunks(session, document, chunks)
            document.updated_at = get_current_timestamp()
            document.status = DocumentStatus.ready()

            documents_indexed += 1

            # Batch commit every 10 documents (for ready status updates)
            if documents_indexed % 10 == 0:
                logger.info(
                    f"Committing batch: {documents_indexed} Gmail messages processed so far"
                )
                await session.commit()

        except Exception as e:
            logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
            # Mark document as failed with reason (visible in UI)
            try:
                document.status = DocumentStatus.failed(str(e))
                document.updated_at = get_current_timestamp()
            except Exception as status_error:
                logger.error(
                    f"Failed to update document status to failed: {status_error}"
                )
            documents_failed += 1
            continue

    return documents_indexed, documents_failed


async def index_composio_gmail(
    session: AsyncSession,
    connector,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None,
    end_date: str | None,
    task_logger: TaskLoggingService,
    log_entry,
    update_last_indexed: bool = True,
    max_items: int = 1000,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str]:
    """Index Gmail messages via Composio with real-time document status updates."""
    try:
        composio_connector = ComposioGmailConnector(session, connector_id)

        # Normalize date values - handle "undefined" strings from frontend
        if start_date == "undefined" or start_date == "":
            start_date = None
        if end_date == "undefined" or end_date == "":
            end_date = None

        # Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
        if start_date is not None and end_date is not None:
            start_date_str = start_date
            end_date_str = end_date
        else:
            start_date_str, end_date_str = calculate_date_range(
                connector, start_date, end_date, default_days_back=365
            )

        # Build query with date range
        query_parts = []
        if start_date_str:
            query_parts.append(f"after:{start_date_str.replace('-', '/')}")
        if end_date_str:
            query_parts.append(f"before:{end_date_str.replace('-', '/')}")
        query = " ".join(query_parts) if query_parts else ""

        logger.info(
            f"Gmail query for connector {connector_id}: '{query}' "
            f"(start_date={start_date_str}, end_date={end_date_str})"
        )

        await task_logger.log_task_progress(
            log_entry,
            f"Fetching Gmail messages via Composio for connector {connector_id}",
            {"stage": "fetching_messages"},
        )

        # =======================================================================
        # FETCH ALL MESSAGES FIRST
        # =======================================================================
        batch_size = 50
        page_token = None
        all_messages = []
        result_size_estimate = None
        last_heartbeat_time = time.time()

        while len(all_messages) < max_items:
            # Send heartbeat periodically
            if on_heartbeat_callback:
                current_time = time.time()
                if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
                    await on_heartbeat_callback(len(all_messages))
                    last_heartbeat_time = current_time

            remaining = max_items - len(all_messages)
            current_batch_size = min(batch_size, remaining)

            (
                messages,
                next_token,
                result_size_estimate_batch,
                error,
            ) = await composio_connector.list_gmail_messages(
                query=query,
                max_results=current_batch_size,
                page_token=page_token,
            )

            if error:
                await task_logger.log_task_failure(
                    log_entry, f"Failed to fetch Gmail messages: {error}", {}
                )
                return 0, f"Failed to fetch Gmail messages: {error}"

            if not messages:
                break

            if result_size_estimate is None and result_size_estimate_batch is not None:
                result_size_estimate = result_size_estimate_batch
                logger.info(
                    f"Gmail API estimated {result_size_estimate} total messages for query: '{query}'"
                )

            all_messages.extend(messages)
            logger.info(
                f"Fetched {len(messages)} messages (total: {len(all_messages)})"
            )

            if not next_token or len(messages) < current_batch_size:
                break

            page_token = next_token

        if not all_messages:
            success_msg = "No Gmail messages found in the specified date range"
            await task_logger.log_task_success(
                log_entry, success_msg, {"messages_count": 0}
            )
            await update_connector_last_indexed(session, connector, update_last_indexed)
            await session.commit()
            return (
                0,
                None,
            )  # Return None (not an error) when no items are found - this is success with 0 items

        logger.info(f"Found {len(all_messages)} Gmail messages to index via Composio")

        # =======================================================================
        # PHASE 1: Analyze all messages, create pending documents
        # This makes ALL documents visible in the UI immediately with pending status
        # =======================================================================
        await task_logger.log_task_progress(
            log_entry,
            f"Phase 1: Creating pending documents for {len(all_messages)} messages",
            {"stage": "phase1_pending"},
        )

        (
            messages_to_process,
            documents_skipped,
            duplicate_content_count,
        ) = await _analyze_gmail_messages_phase1(
            session=session,
            messages=all_messages,
            composio_connector=composio_connector,
            connector_id=connector_id,
            search_space_id=search_space_id,
            user_id=user_id,
        )

        # Commit all pending documents - they all appear in the UI now
        new_documents_count = len([m for m in messages_to_process if m["is_new"]])
        if new_documents_count > 0:
            logger.info(f"Phase 1: Committing {new_documents_count} pending documents")
            await session.commit()

        # =======================================================================
        # PHASE 2: Process each document one by one
        # Each document transitions: pending → processing → ready/failed
        # =======================================================================
        logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
        await task_logger.log_task_progress(
            log_entry,
            f"Phase 2: Processing {len(messages_to_process)} documents",
            {"stage": "phase2_processing"},
        )

        documents_indexed, documents_failed = await _process_gmail_messages_phase2(
            session=session,
            messages_to_process=messages_to_process,
            connector_id=connector_id,
            search_space_id=search_space_id,
            user_id=user_id,
            enable_summary=getattr(connector, "enable_summary", False),
            on_heartbeat_callback=on_heartbeat_callback,
        )

        # CRITICAL: Always update the timestamp so Electric SQL syncs
        await update_connector_last_indexed(session, connector, update_last_indexed)

        # Final commit to ensure all documents are persisted
        logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
        try:
            await session.commit()
            logger.info(
                "Successfully committed all Composio Gmail document changes to database"
            )
        except Exception as e:
            # Handle any remaining integrity errors gracefully
            if (
                "duplicate key value violates unique constraint" in str(e).lower()
                or "uniqueviolationerror" in str(e).lower()
            ):
                logger.warning(
                    f"Duplicate content_hash detected during final commit. "
                    f"Rolling back and continuing. Error: {e!s}"
                )
                await session.rollback()
            else:
                raise

        # Build warning message if there were issues
        warning_parts = []
        if duplicate_content_count > 0:
            warning_parts.append(f"{duplicate_content_count} duplicate")
        if documents_failed > 0:
            warning_parts.append(f"{documents_failed} failed")
        warning_message = ", ".join(warning_parts) if warning_parts else None

        await task_logger.log_task_success(
            log_entry,
            f"Successfully completed Gmail indexing via Composio for connector {connector_id}",
            {
                "documents_indexed": documents_indexed,
                "documents_skipped": documents_skipped,
                "documents_failed": documents_failed,
                "duplicate_content_count": duplicate_content_count,
            },
        )

        logger.info(
            f"Composio Gmail indexing completed: {documents_indexed} ready, "
            f"{documents_skipped} skipped, {documents_failed} failed "
            f"({duplicate_content_count} duplicate content)"
        )
        return documents_indexed, warning_message

    except Exception as e:
        logger.error(f"Failed to index Gmail via Composio: {e!s}", exc_info=True)
        return 0, f"Failed to index Gmail via Composio: {e!s}"
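# --- Illustrative sketch (not part of the diff) ------------------------------
# A callback matching the HeartbeatCallbackType used by the removed indexer
# above; the names are placeholders.
#
#     async def heartbeat(items_processed: int) -> None:
#         """Report liveness so a long-running indexing task is not reaped."""
#         logger.info(f"indexer alive, {items_processed} items processed")
#
#     await index_composio_gmail(..., on_heartbeat_callback=heartbeat)
# ------------------------------------------------------------------------------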
|
|
@ -1,566 +0,0 @@
|
|||
"""
|
||||
Composio Google Calendar Connector Module.
|
||||
|
||||
Provides Google Calendar specific methods for data retrieval and indexing via Composio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
calculate_date_range,
|
||||
check_duplicate_document_by_hash,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# Heartbeat configuration
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||

def get_current_timestamp() -> datetime:
    """Get the current timestamp with timezone for updated_at field."""
    return datetime.now(UTC)


async def check_document_by_unique_identifier(
    session: AsyncSession, unique_identifier_hash: str
) -> Document | None:
    """Check if a document with the given unique identifier hash already exists."""
    existing_doc_result = await session.execute(
        select(Document)
        .options(selectinload(Document.chunks))
        .where(Document.unique_identifier_hash == unique_identifier_hash)
    )
    return existing_doc_result.scalars().first()


async def update_connector_last_indexed(
    session: AsyncSession,
    connector,
    update_last_indexed: bool = True,
) -> None:
    """Update the last_indexed_at timestamp for a connector."""
    if update_last_indexed:
        connector.last_indexed_at = datetime.now(UTC)
        logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")


class ComposioGoogleCalendarConnector(ComposioConnector):
    """
    Google Calendar specific Composio connector.

    Provides methods for listing calendar events and formatting them from
    Google Calendar via Composio.
    """

    async def list_calendar_events(
        self,
        time_min: str | None = None,
        time_max: str | None = None,
        max_results: int = 250,
    ) -> tuple[list[dict[str, Any]], str | None]:
        """
        List Google Calendar events via Composio.

        Args:
            time_min: Start time (RFC3339 format).
            time_max: End time (RFC3339 format).
            max_results: Maximum number of events.

        Returns:
            Tuple of (events list, error message).
        """
        connected_account_id = await self.get_connected_account_id()
        if not connected_account_id:
            return [], "No connected account ID found"

        entity_id = await self.get_entity_id()
        service = await self._get_service()
        return await service.get_calendar_events(
            connected_account_id=connected_account_id,
            entity_id=entity_id,
            time_min=time_min,
            time_max=time_max,
            max_results=max_results,
        )

    def format_calendar_event_to_markdown(self, event: dict[str, Any]) -> str:
        """
        Format a Google Calendar event to markdown.

        Args:
            event: Event object from Google Calendar API.

        Returns:
            Formatted markdown string.
        """
        try:
            # Extract basic event information
            summary = event.get("summary", "No Title")
            description = event.get("description", "")
            location = event.get("location", "")

            # Extract start and end times
            start = event.get("start", {})
            end = event.get("end", {})

            start_time = start.get("dateTime") or start.get("date", "")
            end_time = end.get("dateTime") or end.get("date", "")

            # Format times for display
            def format_time(time_str: str) -> str:
                if not time_str:
                    return "Unknown"
                try:
                    if "T" in time_str:
                        dt = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
                        return dt.strftime("%Y-%m-%d %H:%M")
                    return time_str
                except Exception:
                    return time_str

            start_formatted = format_time(start_time)
            end_formatted = format_time(end_time)

            # Extract attendees
            attendees = event.get("attendees", [])
            attendee_list = []
            for attendee in attendees:
                email = attendee.get("email", "")
                display_name = attendee.get("displayName", email)
                response_status = attendee.get("responseStatus", "")
                attendee_list.append(f"- {display_name} ({response_status})")

            # Build markdown content
            markdown_content = f"# {summary}\n\n"
            markdown_content += f"**Start:** {start_formatted}\n"
            markdown_content += f"**End:** {end_formatted}\n"

            if location:
                markdown_content += f"**Location:** {location}\n"

            markdown_content += "\n"

            if description:
                markdown_content += f"## Description\n\n{description}\n\n"

            if attendee_list:
                markdown_content += "## Attendees\n\n"
                markdown_content += "\n".join(attendee_list)
                markdown_content += "\n\n"

            # Add event metadata
            markdown_content += "## Event Details\n\n"
            markdown_content += f"- **Event ID:** {event.get('id', 'Unknown')}\n"
            markdown_content += f"- **Created:** {event.get('created', 'Unknown')}\n"
            markdown_content += f"- **Updated:** {event.get('updated', 'Unknown')}\n"

            return markdown_content

        except Exception as e:
            return f"Error formatting event to markdown: {e!s}"
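To make the formatter concrete, here is a sample event dict of the shape the Google Calendar API returns, with the markdown the method above would produce traced by hand (abridged):

```python
sample_event = {
    "id": "evt_123",
    "summary": "Sprint planning",
    "description": "Plan the next sprint.",
    "location": "Room 4",
    "start": {"dateTime": "2025-01-06T10:00:00Z"},
    "end": {"dateTime": "2025-01-06T11:00:00Z"},
    "attendees": [{"email": "a@example.com", "responseStatus": "accepted"}],
}

# format_calendar_event_to_markdown(sample_event) would yield, abridged:
#   # Sprint planning
#   **Start:** 2025-01-06 10:00
#   **End:** 2025-01-06 11:00
#   **Location:** Room 4
#
#   ## Description
#   Plan the next sprint.
#
#   ## Attendees
#   - a@example.com (accepted)
```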


# ============ Indexer Functions ============


async def index_composio_google_calendar(
    session: AsyncSession,
    connector,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None,
    end_date: str | None,
    task_logger: TaskLoggingService,
    log_entry,
    update_last_indexed: bool = True,
    max_items: int = 2500,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
    """Index Google Calendar events via Composio."""
    try:
        composio_connector = ComposioGoogleCalendarConnector(session, connector_id)

        await task_logger.log_task_progress(
            log_entry,
            f"Fetching Google Calendar events via Composio for connector {connector_id}",
            {"stage": "fetching_events"},
        )

        # Normalize date values - handle "undefined" strings from frontend
        if start_date == "undefined" or start_date == "":
            start_date = None
        if end_date == "undefined" or end_date == "":
            end_date = None

        # Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
        # This ensures user-selected dates are respected (matching non-Composio Calendar connector behavior)
        if start_date is not None and end_date is not None:
            # User provided both dates - use them directly
            start_date_str = start_date
            end_date_str = end_date
        else:
            # Calculate date range with defaults (uses last_indexed_at or 365 days back)
            # This ensures indexing works even when user doesn't specify dates
            start_date_str, end_date_str = calculate_date_range(
                connector, start_date, end_date, default_days_back=365
            )

        # Build time range for API call
        time_min = f"{start_date_str}T00:00:00Z"
        time_max = f"{end_date_str}T23:59:59Z"

        logger.info(
            f"Google Calendar query for connector {connector_id}: "
            f"(start_date={start_date_str}, end_date={end_date_str})"
        )

        events, error = await composio_connector.list_calendar_events(
            time_min=time_min,
            time_max=time_max,
            max_results=max_items,
        )

        if error:
            await task_logger.log_task_failure(
                log_entry, f"Failed to fetch Calendar events: {error}", {}
            )
            return 0, f"Failed to fetch Calendar events: {error}"

        if not events:
            success_msg = "No Google Calendar events found in the specified date range"
            await task_logger.log_task_success(
                log_entry, success_msg, {"events_count": 0}
            )
            # CRITICAL: Update timestamp even when no events found so Electric SQL syncs and UI shows indexed status
            await update_connector_last_indexed(session, connector, update_last_indexed)
            await session.commit()
            return (
                0,
                None,
            )  # Return None (not error) when no items found - this is success with 0 items

        logger.info(f"Found {len(events)} Google Calendar events to index via Composio")

        documents_indexed = 0
        documents_skipped = 0
        documents_failed = 0  # Track events that failed processing
        duplicate_content_count = (
            0  # Track events skipped due to duplicate content_hash
        )
        last_heartbeat_time = time.time()

        # =======================================================================
        # PHASE 1: Analyze all events, create pending documents
        # This makes ALL documents visible in the UI immediately with pending status
        # =======================================================================
        events_to_process = []  # List of dicts with document and event data
        new_documents_created = False

        for event in events:
            try:
                # Handle both standard Google API and potential Composio variations
                event_id = event.get("id", "") or event.get("eventId", "")
                summary = (
                    event.get("summary", "") or event.get("title", "") or "No Title"
                )

                if not event_id:
                    documents_skipped += 1
                    continue

                # Format to markdown
                markdown_content = composio_connector.format_calendar_event_to_markdown(
                    event
                )

                # Generate unique identifier
                document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"])
                unique_identifier_hash = generate_unique_identifier_hash(
                    document_type, f"calendar_{event_id}", search_space_id
                )

                content_hash = generate_content_hash(markdown_content, search_space_id)

                existing_document = await check_document_by_unique_identifier(
                    session, unique_identifier_hash
                )

                # Extract event times
                start = event.get("start", {})
                end = event.get("end", {})
                start_time = start.get("dateTime") or start.get("date", "")
                end_time = end.get("dateTime") or end.get("date", "")
                location = event.get("location", "")

                if existing_document:
                    if existing_document.content_hash == content_hash:
                        # Ensure status is ready (might have been stuck in processing/pending)
                        if not DocumentStatus.is_state(
                            existing_document.status, DocumentStatus.READY
                        ):
                            existing_document.status = DocumentStatus.ready()
                        documents_skipped += 1
                        continue

                    # Queue existing document for update (will be set to processing in Phase 2)
                    events_to_process.append(
                        {
                            "document": existing_document,
                            "is_new": False,
                            "markdown_content": markdown_content,
                            "content_hash": content_hash,
                            "event_id": event_id,
                            "summary": summary,
                            "start_time": start_time,
                            "end_time": end_time,
                            "location": location,
                        }
                    )
                    continue

                # Document doesn't exist by unique_identifier_hash
                # Check if a document with the same content_hash exists (from standard connector)
                with session.no_autoflush:
                    duplicate_by_content = await check_duplicate_document_by_hash(
                        session, content_hash
                    )

                if duplicate_by_content:
                    logger.info(
                        f"Event {summary} already indexed by another connector "
                        f"(existing document ID: {duplicate_by_content.id}, "
                        f"type: {duplicate_by_content.document_type}). Skipping."
                    )
                    duplicate_content_count += 1
                    documents_skipped += 1
                    continue

                # Create new document with PENDING status (visible in UI immediately)
                document = Document(
                    search_space_id=search_space_id,
                    title=summary,
                    document_type=DocumentType(
                        TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"]
                    ),
                    document_metadata={
                        "event_id": event_id,
                        "summary": summary,
                        "start_time": start_time,
                        "end_time": end_time,
                        "location": location,
                        "connector_id": connector_id,
                        "toolkit_id": "googlecalendar",
                        "source": "composio",
                    },
                    content="Pending...",  # Placeholder until processed
                    content_hash=unique_identifier_hash,  # Temporary unique value - updated when ready
                    unique_identifier_hash=unique_identifier_hash,
                    embedding=None,
                    chunks=[],  # Empty at creation - safe for async
                    status=DocumentStatus.pending(),  # Pending until processing starts
                    updated_at=get_current_timestamp(),
                    created_by_id=user_id,
                    connector_id=connector_id,
                )
                session.add(document)
                new_documents_created = True

                events_to_process.append(
                    {
                        "document": document,
                        "is_new": True,
                        "markdown_content": markdown_content,
                        "content_hash": content_hash,
                        "event_id": event_id,
                        "summary": summary,
                        "start_time": start_time,
                        "end_time": end_time,
                        "location": location,
                    }
                )

            except Exception as e:
                logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
                documents_failed += 1
                continue

        # Commit all pending documents - they all appear in UI now
        if new_documents_created:
            logger.info(
                f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
            )
            await session.commit()

        # =======================================================================
        # PHASE 2: Process each document one by one
        # Each document transitions: pending → processing → ready/failed
        # =======================================================================
        logger.info(f"Phase 2: Processing {len(events_to_process)} documents")

        for item in events_to_process:
            # Send heartbeat periodically
            if on_heartbeat_callback:
                current_time = time.time()
                if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
                    await on_heartbeat_callback(documents_indexed)
                    last_heartbeat_time = current_time

            document = item["document"]
            try:
                # Set to PROCESSING and commit - shows "processing" in UI for THIS document only
                document.status = DocumentStatus.processing()
                await session.commit()

                # Heavy processing (LLM, embeddings, chunks)
                user_llm = await get_user_long_context_llm(
                    session, user_id, search_space_id
                )

                if user_llm and connector.enable_summary:
                    document_metadata_for_summary = {
                        "event_id": item["event_id"],
                        "summary": item["summary"],
                        "start_time": item["start_time"],
                        "document_type": "Google Calendar Event (Composio)",
                    }
                    (
                        summary_content,
                        summary_embedding,
                    ) = await generate_document_summary(
                        item["markdown_content"],
                        user_llm,
                        document_metadata_for_summary,
                    )
                else:
                    summary_content = (
                        f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
                    )
                    summary_embedding = embed_text(summary_content)

                chunks = await create_document_chunks(item["markdown_content"])

                # Update document to READY with actual content
                document.title = item["summary"]
                document.content = summary_content
                document.content_hash = item["content_hash"]
                document.embedding = summary_embedding
                document.document_metadata = {
                    "event_id": item["event_id"],
                    "summary": item["summary"],
                    "start_time": item["start_time"],
                    "end_time": item["end_time"],
                    "location": item["location"],
                    "connector_id": connector_id,
                    "source": "composio",
                }
                await safe_set_chunks(session, document, chunks)
                document.updated_at = get_current_timestamp()
                document.status = DocumentStatus.ready()

                documents_indexed += 1

                # Batch commit every 10 documents (for ready status updates)
                if documents_indexed % 10 == 0:
                    logger.info(
                        f"Committing batch: {documents_indexed} Google Calendar events processed so far"
                    )
                    await session.commit()

            except Exception as e:
                logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
                # Mark document as failed with reason (visible in UI)
                try:
                    document.status = DocumentStatus.failed(str(e))
                    document.updated_at = get_current_timestamp()
                except Exception as status_error:
                    logger.error(
                        f"Failed to update document status to failed: {status_error}"
                    )
                documents_failed += 1
                continue

        # CRITICAL: Always update timestamp (even if 0 documents indexed) so Electric SQL syncs
        # This ensures the UI shows "Last indexed" instead of "Never indexed"
        await update_connector_last_indexed(session, connector, update_last_indexed)

        # Final commit to ensure all documents are persisted (safety net)
        # This matches the pattern used in non-Composio Gmail indexer
        logger.info(
            f"Final commit: Total {documents_indexed} Google Calendar events processed"
        )
        try:
            await session.commit()
            logger.info(
                "Successfully committed all Composio Google Calendar document changes to database"
            )
        except Exception as e:
            # Handle any remaining integrity errors gracefully (race conditions, etc.)
            if (
                "duplicate key value violates unique constraint" in str(e).lower()
                or "uniqueviolationerror" in str(e).lower()
            ):
                logger.warning(
                    f"Duplicate content_hash detected during final commit. "
                    f"This may occur if the same event was indexed by multiple connectors. "
                    f"Rolling back and continuing. Error: {e!s}"
                )
                await session.rollback()
                # Don't fail the entire task - some documents may have been successfully indexed
            else:
                raise

        # Build warning message if there were issues
        warning_parts = []
        if duplicate_content_count > 0:
            warning_parts.append(f"{duplicate_content_count} duplicate")
        if documents_failed > 0:
            warning_parts.append(f"{documents_failed} failed")
        warning_message = ", ".join(warning_parts) if warning_parts else None

        await task_logger.log_task_success(
            log_entry,
            f"Successfully completed Google Calendar indexing via Composio for connector {connector_id}",
            {
                "documents_indexed": documents_indexed,
                "documents_skipped": documents_skipped,
                "documents_failed": documents_failed,
                "duplicate_content_count": duplicate_content_count,
            },
        )

        logger.info(
            f"Composio Google Calendar indexing completed: {documents_indexed} ready, "
            f"{documents_skipped} skipped, {documents_failed} failed "
            f"({duplicate_content_count} duplicate content)"
        )
        return documents_indexed, warning_message

    except Exception as e:
        logger.error(
            f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True
        )
        return 0, f"Failed to index Google Calendar via Composio: {e!s}"
File diff suppressed because it is too large
@@ -14,7 +14,6 @@ from sqlalchemy.future import select
from app.config import config
from app.connectors.confluence_connector import ConfluenceConnector
from app.db import SearchSourceConnector
from app.routes.confluence_add_connector_route import refresh_confluence_token
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption


@@ -190,7 +189,11 @@ class ConfluenceHistoryConnector:
                f"Connector {self._connector_id} not found; cannot refresh token."
            )

        # Refresh token
        # Lazy import to avoid circular dependency
        from app.routes.confluence_add_connector_route import (
            refresh_confluence_token,
        )

        connector = await refresh_confluence_token(self._session, connector)

        # Reload credentials after refresh
@@ -341,6 +344,61 @@ class ConfluenceHistoryConnector:
            logger.error(f"Confluence API request error: {e!s}", exc_info=True)
            raise Exception(f"Confluence API request failed: {e!s}") from e

    async def _make_api_request_with_method(
        self,
        endpoint: str,
        method: str = "GET",
        json_payload: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Make a request to the Confluence API with a specified HTTP method."""
        if not self._use_oauth:
            raise ValueError("Write operations require OAuth authentication")

        token = await self._get_valid_token()
        base_url = await self._get_base_url()
        http_client = await self._get_client()

        url = f"{base_url}/wiki/api/v2/{endpoint}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
        }

        try:
            method_upper = method.upper()
            if method_upper == "POST":
                response = await http_client.post(
                    url, headers=headers, json=json_payload, params=params
                )
            elif method_upper == "PUT":
                response = await http_client.put(
                    url, headers=headers, json=json_payload, params=params
                )
            elif method_upper == "DELETE":
                response = await http_client.delete(url, headers=headers, params=params)
            else:
                response = await http_client.get(url, headers=headers, params=params)

            response.raise_for_status()
            if response.status_code == 204 or not response.text:
                return {"status": "success"}
            return response.json()
        except httpx.HTTPStatusError as e:
            error_detail = {
                "status_code": e.response.status_code,
                "url": str(e.request.url),
                "response_text": e.response.text,
            }
            logger.error(f"Confluence API HTTP error: {error_detail}")
            raise Exception(
                f"Confluence API request failed (HTTP {e.response.status_code}): {e.response.text}"
            ) from e
        except httpx.RequestError as e:
            logger.error(f"Confluence API request error: {e!s}", exc_info=True)
            raise Exception(f"Confluence API request failed: {e!s}") from e

    async def get_all_spaces(self) -> list[dict[str, Any]]:
        """
        Fetch all spaces from Confluence.
@@ -593,6 +651,65 @@
        except Exception as e:
            return [], f"Error fetching pages: {e!s}"

    async def get_page(self, page_id: str) -> dict[str, Any]:
        """Fetch a single page by ID with body content."""
        return await self._make_api_request(
            f"pages/{page_id}", params={"body-format": "storage"}
        )

    async def create_page(
        self,
        space_id: str,
        title: str,
        body: str,
        parent_page_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Confluence page."""
        payload: dict[str, Any] = {
            "spaceId": space_id,
            "title": title,
            "body": {
                "representation": "storage",
                "value": body,
            },
            "status": "current",
        }
        if parent_page_id:
            payload["parentId"] = parent_page_id
        return await self._make_api_request_with_method(
            "pages", method="POST", json_payload=payload
        )

    async def update_page(
        self,
        page_id: str,
        title: str,
        body: str,
        version_number: int,
    ) -> dict[str, Any]:
        """Update an existing Confluence page (requires version number)."""
        payload: dict[str, Any] = {
            "id": page_id,
            "title": title,
            "body": {
                "representation": "storage",
                "value": body,
            },
            "version": {
                "number": version_number,
            },
            "status": "current",
        }
        return await self._make_api_request_with_method(
            f"pages/{page_id}", method="PUT", json_payload=payload
        )

    async def delete_page(self, page_id: str) -> dict[str, Any]:
        """Delete a Confluence page."""
        return await self._make_api_request_with_method(
            f"pages/{page_id}", method="DELETE"
        )

    async def close(self):
        """Close the HTTP client connection."""
        if self._http_client:
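A sketch of the fetch-modify-update cycle these page methods enable. It assumes the v2 page payload carries `title` and `version.number` (standard for the Confluence v2 pages API, but an assumption here) and takes an already-authenticated connector, so treat it as illustrative:

```python
from typing import Any


async def replace_page_body(conn: Any, page_id: str, new_storage_html: str) -> dict[str, Any]:
    """Confluence requires the next version number on every update."""
    page = await conn.get_page(page_id)
    next_version = page["version"]["number"] + 1  # assumed v2 response shape
    return await conn.update_page(
        page_id=page_id,
        title=page["title"],
        body=new_storage_html,
        version_number=next_version,
    )
```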
@@ -52,44 +52,39 @@ class GoogleCalendarConnector:
    ) -> Credentials:
        """
        Get valid Google OAuth credentials.
        Returns:
            Google OAuth credentials
        Raises:
            ValueError: If credentials have not been set
            Exception: If credential refresh fails

        Supports both native OAuth (with refresh_token) and Composio-sourced
        credentials (with refresh_handler). For Composio credentials, validation
        and DB persistence are skipped since Composio manages its own tokens.
        """
        if not all(
            [
                self._credentials.client_id,
                self._credentials.client_secret,
                self._credentials.refresh_token,
            ]
        has_standard_refresh = bool(self._credentials.refresh_token)

        if has_standard_refresh and not all(
            [self._credentials.client_id, self._credentials.client_secret]
        ):
            raise ValueError(
                "Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
                "Google OAuth credentials (client_id, client_secret) must be set"
            )

        if self._credentials and not self._credentials.expired:
            return self._credentials

        # Create credentials from refresh token
        self._credentials = Credentials(
            token=self._credentials.token,
            refresh_token=self._credentials.refresh_token,
            token_uri=self._credentials.token_uri,
            client_id=self._credentials.client_id,
            client_secret=self._credentials.client_secret,
            scopes=self._credentials.scopes,
            expiry=self._credentials.expiry,
        )
        if has_standard_refresh:
            self._credentials = Credentials(
                token=self._credentials.token,
                refresh_token=self._credentials.refresh_token,
                token_uri=self._credentials.token_uri,
                client_id=self._credentials.client_id,
                client_secret=self._credentials.client_secret,
                scopes=self._credentials.scopes,
                expiry=self._credentials.expiry,
            )

        # Refresh the token if needed
        if self._credentials.expired or not self._credentials.valid:
            try:
                self._credentials.refresh(Request())
                # Update the connector config in DB
                if self._session:
                    # Use connector_id if available, otherwise fall back to user_id query
                # Only persist refreshed token for native OAuth (Composio manages its own)
                if has_standard_refresh and self._session:
                    if self._connector_id:
                        result = await self._session.execute(
                            select(SearchSourceConnector).filter(


@@ -110,7 +105,6 @@ class GoogleCalendarConnector:
                                "GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
                            )

                        # Encrypt sensitive credentials before storing
                        from app.config import config
                        from app.utils.oauth_security import TokenEncryption


@@ -119,7 +113,6 @@ class GoogleCalendarConnector:

                        if token_encrypted and config.SECRET_KEY:
                            token_encryption = TokenEncryption(config.SECRET_KEY)
                            # Encrypt sensitive fields
                            if creds_dict.get("token"):
                                creds_dict["token"] = token_encryption.encrypt_token(
                                    creds_dict["token"]


@@ -143,7 +136,6 @@ class GoogleCalendarConnector:
                        await self._session.commit()
            except Exception as e:
                error_str = str(e)
                # Check if this is an invalid_grant error (token expired/revoked)
                if (
                    "invalid_grant" in error_str.lower()
                    or "token has been expired or revoked" in error_str.lower()
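The `has_standard_refresh` flag drives the whole branch above: native OAuth credentials get rebuilt, validated, and persisted, while Composio-sourced credentials (refresh_handler, no refresh_token) skip all of that. A stripped-down, runnable sketch of the decision, where `CredsLike` is an illustrative stand-in for google.oauth2's Credentials:

```python
from dataclasses import dataclass


@dataclass
class CredsLike:
    """Illustrative stand-in for google.oauth2.credentials.Credentials."""
    refresh_token: str | None
    client_id: str | None
    client_secret: str | None


def validate_for_native_refresh(creds: CredsLike) -> bool:
    """Mirror of the branch above: only native OAuth needs client_id/client_secret."""
    has_standard_refresh = bool(creds.refresh_token)
    if has_standard_refresh and not all([creds.client_id, creds.client_secret]):
        raise ValueError("Google OAuth credentials (client_id, client_secret) must be set")
    # False means Composio-managed: skip rebuilding and skip DB persistence
    return has_standard_refresh


print(validate_for_native_refresh(CredsLike("rt", "cid", "secret")))  # True  -> native path
print(validate_for_native_refresh(CredsLike(None, None, None)))       # False -> Composio path
```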
@@ -3,6 +3,7 @@
import io
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload


@@ -15,16 +16,24 @@ from .file_types import GOOGLE_DOC, GOOGLE_SHEET
class GoogleDriveClient:
    """Client for Google Drive API operations."""

    def __init__(self, session: AsyncSession, connector_id: int):
    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        credentials: "Credentials | None" = None,
    ):
        """
        Initialize Google Drive client.

        Args:
            session: Database session
            connector_id: ID of the Drive connector
            credentials: Pre-built credentials (e.g. from Composio). If None,
                credentials are loaded from the DB connector config.
        """
        self.session = session
        self.connector_id = connector_id
        self._credentials = credentials
        self.service = None

    async def get_service(self):


@@ -41,7 +50,12 @@ class GoogleDriveClient:
            return self.service

        try:
            credentials = await get_valid_credentials(self.session, self.connector_id)
            if self._credentials:
                credentials = self._credentials
            else:
                credentials = await get_valid_credentials(
                    self.session, self.connector_id
                )
            self.service = build("drive", "v3", credentials=credentials)
            return self.service
        except Exception as e:
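The new `credentials` parameter gives GoogleDriveClient two construction paths; a sketch of choosing between them, assuming the class above is importable (the session and credential objects are placeholders):

```python
from typing import Any


async def build_drive_service(
    session: Any, connector_id: int, composio_creds: Any | None = None
) -> Any:
    """Pick injected credentials when present, otherwise fall back to the DB."""
    if composio_creds is not None:
        # Composio path: skip the get_valid_credentials() DB lookup entirely
        client = GoogleDriveClient(session, connector_id, credentials=composio_creds)
    else:
        # Native path: get_service() loads credentials from the connector config
        client = GoogleDriveClient(session, connector_id)
    return await client.get_service()
```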
@@ -26,6 +26,7 @@ async def download_and_process_file(
    task_logger: TaskLoggingService,
    log_entry: Log,
    connector_id: int | None = None,
    enable_summary: bool = True,
) -> tuple[Any, str | None, dict[str, Any] | None]:
    """
    Download Google Drive file and process using Surfsense file processors.


@@ -95,6 +96,7 @@ async def download_and_process_file(
        },
    }
    # Include connector_id for de-indexing support
    connector_info["enable_summary"] = enable_summary
    if connector_id is not None:
        connector_info["connector_id"] = connector_id
@@ -81,44 +81,39 @@ class GoogleGmailConnector:
    ) -> Credentials:
        """
        Get valid Google OAuth credentials.
        Returns:
            Google OAuth credentials
        Raises:
            ValueError: If credentials have not been set
            Exception: If credential refresh fails

        Supports both native OAuth (with refresh_token) and Composio-sourced
        credentials (with refresh_handler). For Composio credentials, validation
        and DB persistence are skipped since Composio manages its own tokens.
        """
        if not all(
            [
                self._credentials.client_id,
                self._credentials.client_secret,
                self._credentials.refresh_token,
            ]
        has_standard_refresh = bool(self._credentials.refresh_token)

        if has_standard_refresh and not all(
            [self._credentials.client_id, self._credentials.client_secret]
        ):
            raise ValueError(
                "Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
                "Google OAuth credentials (client_id, client_secret) must be set"
            )

        if self._credentials and not self._credentials.expired:
            return self._credentials

        # Create credentials from refresh token
        self._credentials = Credentials(
            token=self._credentials.token,
            refresh_token=self._credentials.refresh_token,
            token_uri=self._credentials.token_uri,
            client_id=self._credentials.client_id,
            client_secret=self._credentials.client_secret,
            scopes=self._credentials.scopes,
            expiry=self._credentials.expiry,
        )
        if has_standard_refresh:
            self._credentials = Credentials(
                token=self._credentials.token,
                refresh_token=self._credentials.refresh_token,
                token_uri=self._credentials.token_uri,
                client_id=self._credentials.client_id,
                client_secret=self._credentials.client_secret,
                scopes=self._credentials.scopes,
                expiry=self._credentials.expiry,
            )

        # Refresh the token if needed
        if self._credentials.expired or not self._credentials.valid:
            try:
                self._credentials.refresh(Request())
                # Update the connector config in DB
                if self._session:
                    # Use connector_id if available, otherwise fall back to user_id query
                # Only persist refreshed token for native OAuth (Composio manages its own)
                if has_standard_refresh and self._session:
                    if self._connector_id:
                        result = await self._session.execute(
                            select(SearchSourceConnector).filter(


@@ -138,12 +133,38 @@ class GoogleGmailConnector:
                            raise RuntimeError(
                                "GMAIL connector not found; cannot persist refreshed token."
                            )
                        connector.config = json.loads(self._credentials.to_json())

                        from app.config import config
                        from app.utils.oauth_security import TokenEncryption

                        creds_dict = json.loads(self._credentials.to_json())
                        token_encrypted = connector.config.get("_token_encrypted", False)

                        if token_encrypted and config.SECRET_KEY:
                            token_encryption = TokenEncryption(config.SECRET_KEY)
                            if creds_dict.get("token"):
                                creds_dict["token"] = token_encryption.encrypt_token(
                                    creds_dict["token"]
                                )
                            if creds_dict.get("refresh_token"):
                                creds_dict["refresh_token"] = (
                                    token_encryption.encrypt_token(
                                        creds_dict["refresh_token"]
                                    )
                                )
                            if creds_dict.get("client_secret"):
                                creds_dict["client_secret"] = (
                                    token_encryption.encrypt_token(
                                        creds_dict["client_secret"]
                                    )
                                )
                            creds_dict["_token_encrypted"] = True

                        connector.config = creds_dict
                        flag_modified(connector, "config")
                        await self._session.commit()
            except Exception as e:
                error_str = str(e)
                # Check if this is an invalid_grant error (token expired/revoked)
                if (
                    "invalid_grant" in error_str.lower()
                    or "token has been expired or revoked" in error_str.lower()
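The persistence block above encrypts only the sensitive fields of the credentials dict and tags it with `_token_encrypted`. TokenEncryption's internals are not part of this diff, so here is the same idea sketched directly with Fernet, which is an assumption about the cipher rather than the project's actual implementation:

```python
from cryptography.fernet import Fernet

SENSITIVE_FIELDS = ("token", "refresh_token", "client_secret")


def encrypt_config(creds_dict: dict, key: bytes) -> dict:
    """Encrypt only the sensitive fields, then mark the dict as encrypted."""
    cipher = Fernet(key)  # assumed cipher; the project wraps this in TokenEncryption
    out = dict(creds_dict)
    for field in SENSITIVE_FIELDS:
        if out.get(field):
            out[field] = cipher.encrypt(out[field].encode()).decode()
    out["_token_encrypted"] = True  # same marker the connector config uses
    return out


key = Fernet.generate_key()
print(encrypt_config({"token": "ya29.abc", "scopes": ["gmail"]}, key))
```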
@@ -167,14 +167,23 @@ class JiraConnector:
        # Use direct base URL (works for both OAuth and legacy)
        url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"

        if method.upper() == "POST":
        method_upper = method.upper()
        if method_upper == "POST":
            response = requests.post(
                url, headers=headers, json=json_payload, timeout=500
            )
        elif method_upper == "PUT":
            response = requests.put(
                url, headers=headers, json=json_payload, timeout=500
            )
        elif method_upper == "DELETE":
            response = requests.delete(url, headers=headers, params=params, timeout=500)
        else:
            response = requests.get(url, headers=headers, params=params, timeout=500)

        if response.status_code == 200:
        if response.status_code in (200, 201, 204):
            if response.status_code == 204 or not response.text:
                return {"status": "success"}
            return response.json()
        else:
            raise Exception(


@@ -352,6 +361,91 @@ class JiraConnector:
        except Exception as e:
            return [], f"Error fetching issues: {e!s}"

    def get_myself(self) -> dict[str, Any]:
        """Fetch the current user's profile (health check)."""
        return self.make_api_request("myself")

    def get_projects(self) -> list[dict[str, Any]]:
        """Fetch all projects the user has access to."""
        result = self.make_api_request("project/search")
        return result.get("values", [])

    def get_issue_types(self) -> list[dict[str, Any]]:
        """Fetch all issue types."""
        return self.make_api_request("issuetype")

    def get_priorities(self) -> list[dict[str, Any]]:
        """Fetch all priority levels."""
        return self.make_api_request("priority")

    def get_issue(self, issue_id_or_key: str) -> dict[str, Any]:
        """Fetch a single issue by ID or key."""
        return self.make_api_request(f"issue/{issue_id_or_key}")

    def create_issue(
        self,
        project_key: str,
        summary: str,
        issue_type: str = "Task",
        description: str | None = None,
        priority: str | None = None,
        assignee_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Jira issue."""
        fields: dict[str, Any] = {
            "project": {"key": project_key},
            "summary": summary,
            "issuetype": {"name": issue_type},
        }
        if description:
            fields["description"] = {
                "type": "doc",
                "version": 1,
                "content": [
                    {
                        "type": "paragraph",
                        "content": [{"type": "text", "text": description}],
                    }
                ],
            }
        if priority:
            fields["priority"] = {"name": priority}
        if assignee_id:
            fields["assignee"] = {"accountId": assignee_id}

        return self.make_api_request(
            "issue", method="POST", json_payload={"fields": fields}
        )

    def update_issue(
        self, issue_id_or_key: str, fields: dict[str, Any]
    ) -> dict[str, Any]:
        """Update an existing Jira issue's fields."""
        return self.make_api_request(
            f"issue/{issue_id_or_key}",
            method="PUT",
            json_payload={"fields": fields},
        )

    def delete_issue(self, issue_id_or_key: str) -> dict[str, Any]:
        """Delete a Jira issue."""
        return self.make_api_request(f"issue/{issue_id_or_key}", method="DELETE")

    def get_transitions(self, issue_id_or_key: str) -> list[dict[str, Any]]:
        """Get available transitions for an issue (for status changes)."""
        result = self.make_api_request(f"issue/{issue_id_or_key}/transitions")
        return result.get("transitions", [])

    def transition_issue(
        self, issue_id_or_key: str, transition_id: str
    ) -> dict[str, Any]:
        """Transition an issue to a new status."""
        return self.make_api_request(
            f"issue/{issue_id_or_key}/transitions",
            method="POST",
            json_payload={"transition": {"id": transition_id}},
        )

    def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
        """
        Format an issue for easier consumption.
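A sketch of the issue lifecycle the new write methods support, given an already-configured `JiraConnector` instance. The `Done` transition name and the `key` field in the create response are standard Jira behavior but still assumptions here:

```python
from typing import Any


def demo_issue_lifecycle(jira: Any, project_key: str) -> None:
    """Create an issue, transition it to Done if available, then delete it."""
    issue = jira.create_issue(
        project_key=project_key,
        summary="Demo issue",
        description="Created via the new write API",  # stored as an ADF paragraph
    )
    key = issue["key"]  # assumed shape of Jira's create response

    done = next((t for t in jira.get_transitions(key) if t["name"] == "Done"), None)
    if done:
        jira.transition_issue(key, done["id"])

    jira.delete_issue(key)
```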
@@ -14,7 +14,6 @@ from sqlalchemy.future import select
from app.config import config
from app.connectors.jira_connector import JiraConnector
from app.db import SearchSourceConnector
from app.routes.jira_add_connector_route import refresh_jira_token
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption


@@ -184,7 +183,9 @@ class JiraHistoryConnector:
                f"Connector {self._connector_id} not found; cannot refresh token."
            )

        # Refresh token
        # Lazy import to avoid circular dependency
        from app.routes.jira_add_connector_route import refresh_jira_token

        connector = await refresh_jira_token(self._session, connector)

        # Reload credentials after refresh
@@ -1,12 +1,12 @@
import asyncio
import contextlib
import logging
import re
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar

from notion_client import AsyncClient
from notion_client.errors import APIResponseError
from notion_markdown import to_notion
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select


@@ -834,106 +834,8 @@ class NotionHistoryConnector:
        return None

    def _markdown_to_blocks(self, markdown: str) -> list[dict[str, Any]]:
        """
        Convert markdown content to Notion blocks.

        This is a simple converter that handles basic markdown.
        For more complex markdown, consider using a proper markdown parser.

        Args:
            markdown: Markdown content

        Returns:
            List of Notion block objects
        """
        blocks = []
        lines = markdown.split("\n")

        for line in lines:
            line = line.strip()

            if not line:
                continue

            # Heading 1
            if line.startswith("# "):
                blocks.append(
                    {
                        "object": "block",
                        "type": "heading_1",
                        "heading_1": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[2:]}}
                            ]
                        },
                    }
                )
            # Heading 2
            elif line.startswith("## "):
                blocks.append(
                    {
                        "object": "block",
                        "type": "heading_2",
                        "heading_2": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[3:]}}
                            ]
                        },
                    }
                )
            # Heading 3
            elif line.startswith("### "):
                blocks.append(
                    {
                        "object": "block",
                        "type": "heading_3",
                        "heading_3": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[4:]}}
                            ]
                        },
                    }
                )
            # Bullet list
            elif line.startswith("- ") or line.startswith("* "):
                blocks.append(
                    {
                        "object": "block",
                        "type": "bulleted_list_item",
                        "bulleted_list_item": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[2:]}}
                            ]
                        },
                    }
                )
            # Numbered list
            elif match := re.match(r"^(\d+)\.\s+(.*)$", line):
                content = match.group(2)  # Extract text after "number. "
                blocks.append(
                    {
                        "object": "block",
                        "type": "numbered_list_item",
                        "numbered_list_item": {
                            "rich_text": [
                                {"type": "text", "text": {"content": content}}
                            ]
                        },
                    }
                )
            # Regular paragraph
            else:
                blocks.append(
                    {
                        "object": "block",
                        "type": "paragraph",
                        "paragraph": {
                            "rich_text": [{"type": "text", "text": {"content": line}}]
                        },
                    }
                )

        return blocks
        """Convert markdown content to Notion blocks using notion-markdown."""
        return to_notion(markdown)

    async def create_page(
        self, title: str, content: str, parent_page_id: str | None = None
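The hand-rolled converter is gone in favor of the notion-markdown package; a quick sketch of the replacement call. The exact block shape is whatever `to_notion` emits, so the expectation in the comments is indicative rather than guaranteed:

```python
from notion_markdown import to_notion

blocks = to_notion("# Title\n\n- first point\n- second point")
# Expect a list of Notion block dicts, e.g. a heading_1 followed by two
# bulleted_list_item blocks, matching what the removed converter produced
# for the same input (exact keys depend on the library's output).
for block in blocks:
    print(block.get("type"))
```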
@@ -63,6 +63,16 @@ class DocumentType(StrEnum):
    COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"


# Native Google document types → their legacy Composio equivalents.
# Old documents may still carry the Composio type until they are re-indexed;
# search, browse, and indexing must transparently handle both.
NATIVE_TO_LEGACY_DOCTYPE: dict[str, str] = {
    "GOOGLE_DRIVE_FILE": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
    "GOOGLE_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
    "GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
}


class SearchSourceConnectorType(StrEnum):
    SERPER_API = "SERPER_API"  # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
    TAVILY_API = "TAVILY_API"
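The mapping's intended use is to widen a native document type into a filter that also matches legacy rows; a minimal runnable sketch (the helper name is illustrative):

```python
NATIVE_TO_LEGACY_DOCTYPE = {
    "GOOGLE_DRIVE_FILE": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
    "GOOGLE_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
    "GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
}


def expand_document_types(requested: str) -> list[str]:
    """Widen a native type to also match not-yet-re-indexed legacy documents."""
    legacy = NATIVE_TO_LEGACY_DOCTYPE.get(requested)
    return [requested, legacy] if legacy else [requested]


print(expand_document_types("GOOGLE_GMAIL_CONNECTOR"))
# ['GOOGLE_GMAIL_CONNECTOR', 'COMPOSIO_GMAIL_CONNECTOR']
print(expand_document_types("NOTION_CONNECTOR"))
# ['NOTION_CONNECTOR']
```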
@@ -712,7 +722,7 @@ class ChatComment(BaseModel, TimestampMixin):
        nullable=False,
        index=True,
    )
    # Denormalized thread_id for efficient Electric SQL subscriptions (one per thread)
    # Denormalized thread_id for efficient Zero subscriptions (one per thread)
    thread_id = Column(
        Integer,
        ForeignKey("new_chat_threads.id", ondelete="CASCADE"),


@@ -782,7 +792,7 @@ class ChatCommentMention(BaseModel, TimestampMixin):
class ChatSessionState(BaseModel):
    """
    Tracks real-time session state for shared chat collaboration.
    One record per thread, synced via Electric SQL.
    One record per thread, synced via Zero.
    """

    __tablename__ = "chat_session_state"
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import time
from datetime import datetime


@@ -157,7 +158,7 @@ class ChucksHybridSearchRetriever:
        query_text: str,
        top_k: int,
        search_space_id: int,
        document_type: str | None = None,
        document_type: str | list[str] | None = None,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        query_embedding: list | None = None,


@@ -217,18 +218,24 @@ class ChucksHybridSearchRetriever:
            func.coalesce(Document.status["state"].astext, "ready") != "deleting",
        ]

        # Add document type filter if provided
        # Add document type filter if provided (single string or list of strings)
        if document_type is not None:
            # Convert string to enum value if needed
            if isinstance(document_type, str):
                try:
                    doc_type_enum = DocumentType[document_type]
                    base_conditions.append(Document.document_type == doc_type_enum)
                except KeyError:
                    # If the document type doesn't exist in the enum, return empty results
                    return []
            type_list = (
                document_type if isinstance(document_type, list) else [document_type]
            )
            doc_type_enums = []
            for dt in type_list:
                if isinstance(dt, str):
                    with contextlib.suppress(KeyError):
                        doc_type_enums.append(DocumentType[dt])
                else:
                    doc_type_enums.append(dt)
            if not doc_type_enums:
                return []
            if len(doc_type_enums) == 1:
                base_conditions.append(Document.document_type == doc_type_enums[0])
            else:
                base_conditions.append(Document.document_type == document_type)
                base_conditions.append(Document.document_type.in_(doc_type_enums))

        # Add time-based filtering if provided
        if start_date is not None:
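With `document_type` widened to accept a list, callers can pass a native/legacy pair straight through to a single SQL `IN` filter. An illustrative call, with the method name and argument values as placeholders since they sit outside this hunk:

```python
results = await retriever.hybrid_search(  # placeholder name for the method above
    query_text="quarterly report",
    top_k=20,
    search_space_id=7,
    # one IN filter now matches both new and not-yet-re-indexed documents
    document_type=["GOOGLE_GMAIL_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR"],
)
```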
@@ -1,3 +1,4 @@
import contextlib
import time
from datetime import datetime


@@ -149,7 +150,7 @@ class DocumentHybridSearchRetriever:
        query_text: str,
        top_k: int,
        search_space_id: int,
        document_type: str | None = None,
        document_type: str | list[str] | None = None,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        query_embedding: list | None = None,


@@ -197,18 +198,24 @@ class DocumentHybridSearchRetriever:
            func.coalesce(Document.status["state"].astext, "ready") != "deleting",
        ]

        # Add document type filter if provided
        # Add document type filter if provided (single string or list of strings)
        if document_type is not None:
            # Convert string to enum value if needed
            if isinstance(document_type, str):
                try:
                    doc_type_enum = DocumentType[document_type]
                    base_conditions.append(Document.document_type == doc_type_enum)
                except KeyError:
                    # If the document type doesn't exist in the enum, return empty results
                    return []
            type_list = (
                document_type if isinstance(document_type, list) else [document_type]
            )
            doc_type_enums = []
            for dt in type_list:
                if isinstance(dt, str):
                    with contextlib.suppress(KeyError):
                        doc_type_enums.append(DocumentType[dt])
                else:
                    doc_type_enums.append(dt)
            if not doc_type_enums:
                return []
            if len(doc_type_enums) == 1:
                base_conditions.append(Document.document_type == doc_type_enums[0])
            else:
                base_conditions.append(Document.document_type == document_type)
                base_conditions.append(Document.document_type.in_(doc_type_enums))

        # Add time-based filtering if provided
        if start_date is not None:
@@ -80,7 +80,7 @@ router.include_router(model_list_router)  # Dynamic LLM model catalogue from Ope
router.include_router(logs_router)
router.include_router(circleback_webhook_router)  # Circleback meeting webhooks
router.include_router(surfsense_docs_router)  # Surfsense documentation for citations
router.include_router(notifications_router)  # Notifications with Electric SQL sync
router.include_router(notifications_router)  # Notifications with Zero sync
router.include_router(composio_router)  # Composio OAuth and toolkit management
router.include_router(public_chat_router)  # Public chat sharing and cloning
router.include_router(incentive_tasks_router)  # Incentive tasks for earning free pages
@@ -199,7 +199,7 @@ async def airtable_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=airtable_oauth_denied"
            )
        else:
            return RedirectResponse(


@@ -316,7 +316,7 @@ async def airtable_callback(
                f"Duplicate Airtable connector detected for user {user_id} with email {user_email}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=airtable-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=airtable-connector"
            )

        # Generate a unique, user-friendly connector name


@@ -348,7 +348,7 @@ async def airtable_callback(
        # Redirect to the frontend with success params for indexing config
        # Using query params to auto-open the popup with config view on new-chat page
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=airtable-connector&connectorId={new_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=airtable-connector&connectorId={new_connector.id}"
        )

    except ValidationError as e:
@@ -148,7 +148,7 @@ async def clickup_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=clickup_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=clickup_oauth_denied"
            )
        else:
            return RedirectResponse(


@@ -326,7 +326,7 @@ async def clickup_callback(

        # Redirect to the frontend with success params
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=clickup-connector"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=clickup-connector"
        )

    except ValidationError as e:
@@ -208,7 +208,7 @@ async def composio_callback(

        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=composio_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=composio_oauth_denied"
            )
        else:
            return RedirectResponse(


@@ -263,6 +263,15 @@ async def composio_callback(
        logger.info(
            f"Successfully got connected_account_id: {final_connected_account_id}"
        )
        # Wait for Composio to finish exchanging the auth code for tokens.
        try:
            service.wait_for_connection(final_connected_account_id, timeout=30.0)
        except Exception:
            logger.warning(
                f"wait_for_connection timed out for {final_connected_account_id}, "
                "proceeding anyway",
                exc_info=True,
            )

        # Build entity_id for Composio API calls (same format as used in initiate)
        entity_id = f"surfsense_{user_id}"
@ -370,7 +379,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}"
|
||||
)
|
||||
|
||||
# This is a NEW account - create a new connector
|
||||
|
|
@ -399,7 +408,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
|
|
@ -425,6 +434,211 @@ async def composio_callback(
|
|||
) from e
|
||||
|
||||
|
||||
COMPOSIO_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/auth/composio/connector/reauth")
|
||||
async def reauth_composio_connector(
|
||||
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """
    Initiate Composio re-authentication for an expired connected account.

    Uses Composio's refresh API so the same connected_account_id stays valid
    after the user completes the OAuth flow again.

    Query params:
        space_id: Search space ID the connector belongs to
        connector_id: ID of the existing Composio connector to re-authenticate
        return_url: Optional frontend path to redirect to after completion
    """
    if not ComposioService.is_enabled():
        raise HTTPException(
            status_code=503, detail="Composio integration is not enabled."
        )

    if not config.SECRET_KEY:
        raise HTTPException(
            status_code=500, detail="SECRET_KEY not configured for OAuth security."
        )

    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type.in_(COMPOSIO_CONNECTOR_TYPES),
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Composio connector not found or access denied",
            )

        connected_account_id = connector.config.get("composio_connected_account_id")
        if not connected_account_id:
            raise HTTPException(
                status_code=400,
                detail="Composio connected account ID not found. Please reconnect the connector.",
            )

        # Build callback URL with secure state
        state_manager = get_state_manager()
        state_encoded = state_manager.generate_secure_state(
            space_id,
            user.id,
            toolkit_id=connector.config.get("toolkit_id", ""),
            connector_id=connector_id,
            return_url=return_url,
        )

        callback_base = config.COMPOSIO_REDIRECT_URI
        if not callback_base:
            backend_url = config.BACKEND_URL or "http://localhost:8000"
            callback_base = (
                f"{backend_url}/api/v1/auth/composio/connector/reauth/callback"
            )
        else:
            # Replace the normal callback path with the reauth one
            callback_base = callback_base.replace(
                "/auth/composio/connector/callback",
                "/auth/composio/connector/reauth/callback",
            )

        callback_url = f"{callback_base}?state={state_encoded}"

        service = ComposioService()
        refresh_result = service.refresh_connected_account(
            connected_account_id=connected_account_id,
            redirect_url=callback_url,
        )

        if refresh_result["redirect_url"] is None:
            # Token refreshed server-side; clear auth_expired immediately
            if connector.config.get("auth_expired"):
                connector.config = {**connector.config, "auth_expired": False}
                flag_modified(connector, "config")
                await session.commit()
            logger.info(
                f"Composio account {connected_account_id} refreshed server-side (no redirect needed)"
            )
            return {
                "success": True,
                "message": "Authentication refreshed successfully.",
            }

        logger.info(f"Initiating Composio re-auth for connector {connector_id}")
        return {"auth_url": refresh_result["redirect_url"]}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Composio re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Composio re-auth: {e!s}"
        ) from e
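The endpoint above returns one of two JSON shapes: {"success": true, ...} when Composio could refresh the token server-side, or {"auth_url": ...} when the user has to repeat the OAuth flow. A minimal client sketch (hypothetical, not part of this PR; it assumes the route is mounted under /api/v1, as the callback-URL construction above suggests, and that `token` is a valid bearer token):

# Hypothetical client-side sketch for the reauth endpoint above.
import httpx

def start_composio_reauth(base_url: str, token: str, space_id: int, connector_id: int) -> str | None:
    """Return an auth URL the user must visit, or None if the token was refreshed server-side."""
    resp = httpx.get(
        f"{base_url}/api/v1/auth/composio/connector/reauth",
        params={"space_id": space_id, "connector_id": connector_id},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    body = resp.json()
    if body.get("success"):
        return None  # refreshed in place; auth_expired was already cleared
    return body["auth_url"]  # redirect the user here to finish OAuth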


@router.get("/auth/composio/connector/reauth/callback")
async def composio_reauth_callback(
    request: Request,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Handle Composio re-authentication callback.

    Clears the auth_expired flag and redirects the user back to the frontend.
    The connected_account_id has not changed — Composio refreshed it in place.
    """
    try:
        if not state:
            raise HTTPException(status_code=400, detail="Missing state parameter")

        state_manager = get_state_manager()
        try:
            data = state_manager.validate_state(state)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid state parameter: {e!s}"
            ) from e

        user_id = UUID(data["user_id"])
        space_id = data["space_id"]
        reauth_connector_id = data.get("connector_id")
        return_url = data.get("return_url")

        if not reauth_connector_id:
            raise HTTPException(status_code=400, detail="Missing connector_id in state")

        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == reauth_connector_id,
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.search_space_id == space_id,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Connector not found or access denied during re-auth callback",
            )

        # Wait for Composio to finish processing new tokens before proceeding.
        # Without this, get_access_token() may return stale credentials.
        connected_account_id = connector.config.get("composio_connected_account_id")
        if connected_account_id:
            try:
                service = ComposioService()
                service.wait_for_connection(connected_account_id, timeout=30.0)
            except Exception:
                logger.warning(
                    f"wait_for_connection timed out for connector {reauth_connector_id}, "
                    "proceeding anyway — tokens may not be ready yet",
                    exc_info=True,
                )

        # Clear auth_expired flag
        connector.config = {**connector.config, "auth_expired": False}
        flag_modified(connector, "config")
        await session.commit()
        await session.refresh(connector)

        logger.info(f"Composio re-auth completed for connector {reauth_connector_id}")

        if return_url and return_url.startswith("/"):
            return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{return_url}")

        frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
            connector.config.get("toolkit_id", ""), "composio-connector"
        )
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={reauth_connector_id}"
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in Composio reauth callback: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to complete Composio re-auth: {e!s}"
        ) from e
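ComposioService.wait_for_connection is defined outside this diff; its role here is a bounded poll that blocks until the connected account reports an active status. An illustrative sketch of that shape (hypothetical; a generic get_status callable and the "ACTIVE" status string stand in for the Composio SDK details):

# Hypothetical sketch of a bounded wait-for-active poll; not the real
# ComposioService.wait_for_connection implementation.
import time
from collections.abc import Callable

def wait_for_connection(get_status: Callable[[], str], timeout: float = 30.0, interval: float = 1.0) -> None:
    """Poll get_status() until it returns "ACTIVE" or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if get_status() == "ACTIVE":
            return
        time.sleep(interval)
    raise TimeoutError("connected account did not become ACTIVE within the timeout")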


@router.get("/connectors/{connector_id}/composio-drive/folders")
async def list_composio_drive_folders(
    connector_id: int,

@@ -433,31 +647,23 @@ async def list_composio_drive_folders(
    user: User = Depends(current_active_user),
):
    """
    List folders AND files in user's Google Drive via Composio with hierarchical support.
    List folders AND files in user's Google Drive via Composio.

    This is called at index time from the manage connector page to display
    the complete file system (folders and files). Only folders are selectable.

    Args:
        connector_id: ID of the Composio Google Drive connector
        parent_id: Optional parent folder ID to list contents (None for root)

    Returns:
        JSON with list of items: {
            "items": [
                {"id": str, "name": str, "mimeType": str, "isFolder": bool, ...},
                ...
            ]
        }
    Uses the same GoogleDriveClient / list_folder_contents path as the native
    connector, with Composio-sourced credentials. This means auth errors
    propagate identically (Google returns 401 → exception → auth_expired flag).
    """
    from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
    from app.utils.google_credentials import build_composio_credentials

    if not ComposioService.is_enabled():
        raise HTTPException(
            status_code=503,
            detail="Composio integration is not enabled.",
        )

    connector = None
    try:
        # Get connector and verify ownership
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,

@@ -474,7 +680,6 @@ async def list_composio_drive_folders(
                detail="Composio Google Drive connector not found or access denied",
            )

        # Get Composio connected account ID from config
        composio_connected_account_id = connector.config.get(
            "composio_connected_account_id"
        )

@@ -484,63 +689,43 @@ async def list_composio_drive_folders(
                detail="Composio connected account not found. Please reconnect the connector.",
            )

        # Initialize Composio service and fetch files
        service = ComposioService()
        entity_id = f"surfsense_{user.id}"
        credentials = build_composio_credentials(composio_connected_account_id)
        drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)

        # Fetch files/folders from Composio Google Drive
        files, _next_token, error = await service.get_drive_files(
            connected_account_id=composio_connected_account_id,
            entity_id=entity_id,
            folder_id=parent_id,
            page_size=100,
        )
        items, error = await list_folder_contents(drive_client, parent_id=parent_id)

        if error:
            logger.error(f"Failed to list Composio Drive files: {error}")
            error_lower = error.lower()
            if (
                "401" in error
                or "invalid_grant" in error_lower
                or "token has been expired or revoked" in error_lower
                or "invalid credentials" in error_lower
                or "authentication failed" in error_lower
            ):
                try:
                    if connector and not connector.config.get("auth_expired"):
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await session.commit()
                        logger.info(
                            f"Marked Composio connector {connector_id} as auth_expired"
                        )
                except Exception:
                    logger.warning(
                        f"Failed to persist auth_expired for connector {connector_id}",
                        exc_info=True,
                    )
                raise HTTPException(
                    status_code=400,
                    detail="Google Drive authentication expired. Please re-authenticate.",
                )
            raise HTTPException(
                status_code=500, detail=f"Failed to list folder contents: {error}"
            )

        # Transform files to match the expected format with isFolder field
        items = []
        for file_info in files:
            file_id = file_info.get("id", "") or file_info.get("fileId", "")
            file_name = (
                file_info.get("name", "") or file_info.get("fileName", "") or "Untitled"
            )
            mime_type = file_info.get("mimeType", "") or file_info.get("mime_type", "")

            if not file_id:
                continue

            is_folder = mime_type == "application/vnd.google-apps.folder"

            items.append(
                {
                    "id": file_id,
                    "name": file_name,
                    "mimeType": mime_type,
                    "isFolder": is_folder,
                    "parents": file_info.get("parents", []),
                    "size": file_info.get("size"),
                    "iconLink": file_info.get("iconLink"),
                }
            )

        # Sort: folders first, then files, both alphabetically
        folders = sorted(
            [item for item in items if item["isFolder"]],
            key=lambda x: x["name"].lower(),
        )
        files_list = sorted(
            [item for item in items if not item["isFolder"]],
            key=lambda x: x["name"].lower(),
        )
        items = folders + files_list

        folder_count = len(folders)
        file_count = len(files_list)
        folder_count = sum(1 for item in items if item.get("isFolder", False))
        file_count = len(items) - folder_count

        logger.info(
            f"Listed {len(items)} total items ({folder_count} folders, {file_count} files) for Composio connector {connector_id}"

@@ -553,6 +738,31 @@ async def list_composio_drive_folders(
        raise
    except Exception as e:
        logger.error(f"Error listing Composio Drive contents: {e!s}", exc_info=True)
        error_lower = str(e).lower()
        if (
            "invalid_grant" in error_lower
            or "token has been expired or revoked" in error_lower
            or "invalid credentials" in error_lower
            or "authentication failed" in error_lower
            or "401" in str(e)
        ):
            try:
                if connector and not connector.config.get("auth_expired"):
                    connector.config = {**connector.config, "auth_expired": True}
                    flag_modified(connector, "config")
                    await session.commit()
                    logger.info(
                        f"Marked Composio connector {connector_id} as auth_expired"
                    )
            except Exception:
                logger.warning(
                    f"Failed to persist auth_expired for connector {connector_id}",
                    exc_info=True,
                )
            raise HTTPException(
                status_code=400,
                detail="Google Drive authentication expired. Please re-authenticate.",
            ) from e
        raise HTTPException(
            status_code=500, detail=f"Failed to list Drive contents: {e!s}"
        ) from e
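The substring heuristic for spotting expired Google credentials recurs in this handler and in the native Drive handlers below. Factored out, it would look roughly like this (a sketch for illustration, not code from the PR):

# Sketch: the auth-expiry heuristic shared by the Drive listing handlers,
# matching exactly the substrings checked in the diff.
def is_auth_error(error: str) -> bool:
    error_lower = error.lower()
    return (
        "401" in error
        or "invalid_grant" in error_lower
        or "token has been expired or revoked" in error_lower
        or "invalid credentials" in error_lower
        or "authentication failed" in error_lower
    )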

@@ -46,6 +46,8 @@ SCOPES = [
    "read:space:confluence",
    "read:page:confluence",
    "read:comment:confluence",
    "write:page:confluence",  # Required for creating/updating pages
    "delete:page:confluence",  # Required for deleting pages
    "offline_access",  # Required for refresh tokens
]


@@ -170,7 +172,7 @@ async def confluence_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=confluence_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=confluence_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -196,6 +198,8 @@ async def confluence_callback(

        user_id = UUID(data["user_id"])
        space_id = data["space_id"]
        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        # Validate redirect URI (security: ensure it matches configured value)
        if not config.CONFLUENCE_REDIRECT_URI:

@@ -292,6 +296,46 @@ async def confluence_callback(
            "_token_encrypted": True,
        }

        # Handle re-authentication: update existing connector instead of creating new one
        if reauth_connector_id:
            from sqlalchemy.future import select as sa_select
            from sqlalchemy.orm.attributes import flag_modified

            result = await session.execute(
                sa_select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )

            db_connector.config = {
                **connector_config,
                "auth_expired": False,
            }
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info(
                f"Re-authenticated Confluence connector {db_connector.id} for user {user_id}"
            )
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=confluence-connector"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=confluence-connector"
            )

        # Extract unique identifier from connector credentials
        connector_identifier = extract_identifier_from_credentials(
            SearchSourceConnectorType.CONFLUENCE_CONNECTOR, connector_config

@@ -310,7 +354,7 @@ async def confluence_callback(
                f"Duplicate Confluence connector detected for user {user_id} with instance {connector_identifier}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=confluence-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=confluence-connector"
            )

        # Generate a unique, user-friendly connector name

@@ -341,7 +385,7 @@ async def confluence_callback(

        # Redirect to the frontend with success params
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=confluence-connector&connectorId={new_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=confluence-connector&connectorId={new_connector.id}"
        )

    except ValidationError as e:

@@ -372,6 +416,73 @@ async def confluence_callback(
        ) from e


@router.get("/auth/confluence/connector/reauth")
async def reauth_confluence(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Initiate Confluence re-authentication to upgrade OAuth scopes."""
    try:
        from sqlalchemy.future import select

        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Confluence connector not found or access denied",
            )

        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500, detail="SECRET_KEY not configured for OAuth security."
            )

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        from urllib.parse import urlencode

        auth_params = {
            "audience": "api.atlassian.com",
            "client_id": config.ATLASSIAN_CLIENT_ID,
            "scope": " ".join(SCOPES),
            "redirect_uri": config.CONFLUENCE_REDIRECT_URI,
            "state": state_encoded,
            "response_type": "code",
            "prompt": "consent",
        }

        auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"

        logger.info(
            f"Initiating Confluence re-auth for user {user.id}, connector {connector_id}"
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Confluence re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Confluence re-auth: {e!s}"
        ) from e
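generate_secure_state and validate_state come from get_state_manager(), which is outside this diff. One plausible shape is an HMAC-SHA256 signature over a JSON payload keyed by config.SECRET_KEY (an illustrative sketch only, not the project's actual state manager):

# Hypothetical sketch of a signed OAuth state helper.
import base64
import hashlib
import hmac
import json

def generate_secure_state(secret_key: str, space_id: int, user_id: str, **extra) -> str:
    payload = json.dumps({"space_id": space_id, "user_id": str(user_id), **extra}).encode()
    sig = hmac.new(secret_key.encode(), payload, hashlib.sha256).digest()
    return (
        base64.urlsafe_b64encode(payload).decode()
        + "."
        + base64.urlsafe_b64encode(sig).decode()
    )

def validate_state(secret_key: str, state: str) -> dict:
    payload_b64, sig_b64 = state.split(".", 1)
    payload = base64.urlsafe_b64decode(payload_b64)
    expected = hmac.new(secret_key.encode(), payload, hashlib.sha256).digest()
    if not hmac.compare_digest(base64.urlsafe_b64decode(sig_b64), expected):
        raise ValueError("state signature mismatch")
    return json.loads(payload)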


async def refresh_confluence_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:

@@ -172,7 +172,7 @@ async def discord_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=discord_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=discord_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -311,7 +311,7 @@ async def discord_callback(
                f"Duplicate Discord connector detected for user {user_id} with server {connector_identifier}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=discord-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=discord-connector"
            )

        # Generate a unique, user-friendly connector name

@@ -342,7 +342,7 @@ async def discord_callback(

        # Redirect to the frontend with success params
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=discord-connector&connectorId={new_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=discord-connector&connectorId={new_connector.id}"
        )

    except ValidationError as e:

@@ -128,7 +128,7 @@ async def create_documents_file_upload(
    Upload files as documents with real-time status tracking.

    Implements 2-phase document status updates for real-time UI feedback:
    - Phase 1: Create all documents with 'pending' status (visible in UI immediately via ElectricSQL)
    - Phase 1: Create all documents with 'pending' status (visible in UI immediately via Zero)
    - Phase 2: Celery processes each file: pending → processing → ready/failed

    Requires DOCUMENTS_CREATE permission.
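The lifecycle in that docstring is a small state machine; a sketch of the legal transitions (status names from the docstring, helper code illustrative):

# Sketch of the 2-phase document status lifecycle. Phase 1 creates rows as
# 'pending' (synced to the UI by Zero right away); phase 2 is the Celery
# worker moving each file pending -> processing -> ready or failed.
VALID_TRANSITIONS = {
    "pending": {"processing"},
    "processing": {"ready", "failed"},
}

def advance(current: str, new: str) -> str:
    if new not in VALID_TRANSITIONS.get(current, set()):
        raise ValueError(f"illegal status transition {current} -> {new}")
    return new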

@@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from google_auth_oauthlib.flow import Flow
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email

@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)

router = APIRouter()

SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
SCOPES = ["https://www.googleapis.com/auth/calendar.events"]
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI

# Initialize security utilities

@@ -111,6 +113,66 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
        ) from e


@router.get("/auth/google/calendar/connector/reauth")
async def reauth_calendar(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Initiate Google Calendar re-authentication for an existing connector."""
    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Google Calendar connector not found or access denied",
            )

        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500, detail="SECRET_KEY not configured for OAuth security."
            )

        flow = get_google_flow()

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        auth_url, _ = flow.authorization_url(
            access_type="offline",
            prompt="consent",
            include_granted_scopes="true",
            state=state_encoded,
        )

        logger.info(
            f"Initiating Google Calendar re-auth for user {user.id}, connector {connector_id}"
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Calendar re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Calendar re-auth: {e!s}"
        ) from e


@router.get("/auth/google/calendar/connector/callback")
async def calendar_callback(
    request: Request,

@@ -137,7 +199,7 @@ async def calendar_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_calendar_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -197,6 +259,42 @@ async def calendar_callback(
        # Mark that credentials are encrypted for backward compatibility
        creds_dict["_token_encrypted"] = True

        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )

            db_connector.config = {**creds_dict}
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info(
                f"Re-authenticated Calendar connector {db_connector.id} for user {user_id}"
            )
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
            )

        # Check for duplicate connector (same account already connected)
        is_duplicate = await check_duplicate_connector(
            session,

@@ -210,7 +308,7 @@ async def calendar_callback(
                f"Duplicate Google Calendar connector detected for user {user_id} with email {user_email}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-calendar-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-calendar-connector"
            )

        try:

@@ -236,7 +334,7 @@ async def calendar_callback(
        # Redirect to the frontend with success params for indexing config
        # Using query params to auto-open the popup with config view on new-chat page
        return RedirectResponse(
            f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
            f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
        )
    except ValidationError as e:
        await session.rollback()

@@ -257,7 +257,7 @@ async def drive_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_drive_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -345,6 +345,7 @@ async def drive_callback(
            db_connector.config = {
                **creds_dict,
                "start_page_token": existing_start_page_token,
                "auth_expired": False,
            }
            from sqlalchemy.orm.attributes import flag_modified

@@ -360,7 +361,7 @@ async def drive_callback(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
            )

        is_duplicate = await check_duplicate_connector(

@@ -375,7 +376,7 @@ async def drive_callback(
                f"Duplicate Google Drive connector detected for user {user_id} with email {user_email}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-drive-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-drive-connector"
            )

        # Generate a unique, user-friendly connector name

@@ -425,7 +426,7 @@ async def drive_callback(
        )

        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
        )

    except HTTPException:

@@ -502,11 +503,35 @@ async def list_google_drive_folders(
        items, error = await list_folder_contents(drive_client, parent_id=parent_id)

        if error:
            error_lower = error.lower()
            if (
                "401" in error
                or "invalid_grant" in error_lower
                or "token has been expired or revoked" in error_lower
                or "invalid credentials" in error_lower
                or "authentication failed" in error_lower
            ):
                from sqlalchemy.orm.attributes import flag_modified

                try:
                    if connector and not connector.config.get("auth_expired"):
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await session.commit()
                        logger.info(f"Marked connector {connector_id} as auth_expired")
                except Exception:
                    logger.warning(
                        f"Failed to persist auth_expired for connector {connector_id}",
                        exc_info=True,
                    )
                raise HTTPException(
                    status_code=400,
                    detail="Google Drive authentication expired. Please re-authenticate.",
                )
            raise HTTPException(
                status_code=500, detail=f"Failed to list folder contents: {error}"
            )

        # Count folders and files for better logging
        folder_count = sum(1 for item in items if item.get("isFolder", False))
        file_count = len(items) - folder_count

@@ -515,7 +540,6 @@ async def list_google_drive_folders(
            + (f" in folder {parent_id}" if parent_id else " in ROOT")
        )

        # Log first few items for debugging
        if items:
            logger.info(f"First 3 items: {[item.get('name') for item in items[:3]]}")

@@ -525,6 +549,31 @@ async def list_google_drive_folders(
        raise
    except Exception as e:
        logger.error(f"Error listing Drive contents: {e!s}", exc_info=True)
        error_lower = str(e).lower()
        if (
            "401" in str(e)
            or "invalid_grant" in error_lower
            or "token has been expired or revoked" in error_lower
            or "invalid credentials" in error_lower
            or "authentication failed" in error_lower
        ):
            from sqlalchemy.orm.attributes import flag_modified

            try:
                if connector and not connector.config.get("auth_expired"):
                    connector.config = {**connector.config, "auth_expired": True}
                    flag_modified(connector, "config")
                    await session.commit()
                    logger.info(f"Marked connector {connector_id} as auth_expired")
            except Exception:
                logger.warning(
                    f"Failed to persist auth_expired for connector {connector_id}",
                    exc_info=True,
                )
            raise HTTPException(
                status_code=400,
                detail="Google Drive authentication expired. Please re-authenticate.",
            ) from e
        raise HTTPException(
            status_code=500, detail=f"Failed to list Drive contents: {e!s}"
        ) from e

@@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from google_auth_oauthlib.flow import Flow
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email

@@ -71,7 +73,7 @@ def get_google_flow():
            }
        },
        scopes=[
            "https://www.googleapis.com/auth/gmail.readonly",
            "https://www.googleapis.com/auth/gmail.modify",
            "https://www.googleapis.com/auth/userinfo.email",
            "https://www.googleapis.com/auth/userinfo.profile",
            "openid",

@@ -129,6 +131,66 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
        ) from e


@router.get("/auth/google/gmail/connector/reauth")
async def reauth_gmail(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Initiate Gmail re-authentication for an existing connector."""
    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Gmail connector not found or access denied",
            )

        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500, detail="SECRET_KEY not configured for OAuth security."
            )

        flow = get_google_flow()

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        auth_url, _ = flow.authorization_url(
            access_type="offline",
            prompt="consent",
            include_granted_scopes="true",
            state=state_encoded,
        )

        logger.info(
            f"Initiating Gmail re-auth for user {user.id}, connector {connector_id}"
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Gmail re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Gmail re-auth: {e!s}"
        ) from e


@router.get("/auth/google/gmail/connector/callback")
async def gmail_callback(
    request: Request,

@@ -168,7 +230,7 @@ async def gmail_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_gmail_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -228,6 +290,42 @@ async def gmail_callback(
        # Mark that credentials are encrypted for backward compatibility
        creds_dict["_token_encrypted"] = True

        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )

            db_connector.config = {**creds_dict}
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info(
                f"Re-authenticated Gmail connector {db_connector.id} for user {user_id}"
            )
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
            )

        # Check for duplicate connector (same account already connected)
        is_duplicate = await check_duplicate_connector(
            session,

@@ -241,7 +339,7 @@ async def gmail_callback(
                f"Duplicate Gmail connector detected for user {user_id} with email {user_email}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-gmail-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-gmail-connector"
            )

        try:

@@ -272,7 +370,7 @@ async def gmail_callback(
        # Redirect to the frontend with success params for indexing config
        # Using query params to auto-open the popup with config view on new-chat page
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
        )

    except IntegrityError as e:

@@ -45,6 +45,7 @@ ACCESSIBLE_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-res
SCOPES = [
    "read:jira-work",
    "read:jira-user",
    "write:jira-work",  # Required for creating/updating/deleting issues
    "offline_access",  # Required for refresh tokens
]

@@ -167,7 +168,7 @@ async def jira_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=jira_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=jira_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -193,6 +194,8 @@ async def jira_callback(

        user_id = UUID(data["user_id"])
        space_id = data["space_id"]
        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        # Validate redirect URI (security: ensure it matches configured value)
        if not config.JIRA_REDIRECT_URI:

@@ -310,6 +313,46 @@ async def jira_callback(
            "_token_encrypted": True,
        }

        # Handle re-authentication: update existing connector instead of creating new one
        if reauth_connector_id:
            from sqlalchemy.future import select as sa_select
            from sqlalchemy.orm.attributes import flag_modified

            result = await session.execute(
                sa_select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.JIRA_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )

            db_connector.config = {
                **connector_config,
                "auth_expired": False,
            }
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info(
                f"Re-authenticated Jira connector {db_connector.id} for user {user_id}"
            )
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=jira-connector"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=jira-connector"
            )

        # Extract unique identifier from connector credentials
        connector_identifier = extract_identifier_from_credentials(
            SearchSourceConnectorType.JIRA_CONNECTOR, connector_config

@@ -328,7 +371,7 @@ async def jira_callback(
                f"Duplicate Jira connector detected for user {user_id} with instance {connector_identifier}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=jira-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=jira-connector"
            )

        # Generate a unique, user-friendly connector name

@@ -359,7 +402,7 @@ async def jira_callback(

        # Redirect to the frontend with success params
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=jira-connector&connectorId={new_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=jira-connector&connectorId={new_connector.id}"
        )

    except ValidationError as e:

@@ -390,6 +433,73 @@ async def jira_callback(
        ) from e


@router.get("/auth/jira/connector/reauth")
async def reauth_jira(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Initiate Jira re-authentication to upgrade OAuth scopes."""
    try:
        from sqlalchemy.future import select

        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.JIRA_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Jira connector not found or access denied",
            )

        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500, detail="SECRET_KEY not configured for OAuth security."
            )

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        from urllib.parse import urlencode

        auth_params = {
            "audience": "api.atlassian.com",
            "client_id": config.ATLASSIAN_CLIENT_ID,
            "scope": " ".join(SCOPES),
            "redirect_uri": config.JIRA_REDIRECT_URI,
            "state": state_encoded,
            "response_type": "code",
            "prompt": "consent",
        }

        auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"

        logger.info(
            f"Initiating Jira re-auth for user {user.id}, connector {connector_id}"
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Jira re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Jira re-auth: {e!s}"
        ) from e


async def refresh_jira_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:

@@ -12,8 +12,10 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.connectors.linear_connector import fetch_linear_organization_name

@@ -127,6 +129,70 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
        ) from e


@router.get("/auth/linear/connector/reauth")
async def reauth_linear(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Initiate Linear re-authentication for an existing connector."""
    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.LINEAR_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(
                status_code=404,
                detail="Linear connector not found or access denied",
            )

        if not config.LINEAR_CLIENT_ID:
            raise HTTPException(status_code=500, detail="Linear OAuth not configured.")
        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500, detail="SECRET_KEY not configured for OAuth security."
            )

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        from urllib.parse import urlencode

        auth_params = {
            "client_id": config.LINEAR_CLIENT_ID,
            "response_type": "code",
            "redirect_uri": config.LINEAR_REDIRECT_URI,
            "scope": " ".join(SCOPES),
            "state": state_encoded,
        }
        auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"

        logger.info(
            f"Initiating Linear re-auth for user {user.id}, connector {connector_id}"
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to initiate Linear re-auth: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate Linear re-auth: {e!s}"
        ) from e


@router.get("/auth/linear/connector/callback")
async def linear_callback(
    request: Request,

@@ -166,7 +232,7 @@ async def linear_callback(
        # Redirect to frontend with error parameter
        if space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=linear_oauth_denied"
            )
        else:
            return RedirectResponse(

@@ -267,6 +333,43 @@ async def linear_callback(
            "_token_encrypted": True,
        }

        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.LINEAR_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )

            connector_config["organization_name"] = org_name
            db_connector.config = connector_config
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info(
                f"Re-authenticated Linear connector {db_connector.id} for user {user_id}"
            )
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={db_connector.id}"
            )

        # Check for duplicate connector (same organization already connected)
        is_duplicate = await check_duplicate_connector(
            session,

@@ -280,7 +383,7 @@ async def linear_callback(
                f"Duplicate Linear connector detected for user {user_id} with org {org_name}"
            )
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=linear-connector"
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=linear-connector"
            )

        # Generate a unique, user-friendly connector name

@@ -292,6 +395,7 @@ async def linear_callback(
            org_name,
        )
        # Create new connector
        connector_config["organization_name"] = org_name
        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,

@@ -311,7 +415,7 @@ async def linear_callback(

        # Redirect to the frontend with success params
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector&connectorId={new_connector.id}"
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={new_connector.id}"
        )

    except ValidationError as e:

@@ -342,6 +446,22 @@ async def linear_callback(
        ) from e


async def _mark_connector_auth_expired(
    session: AsyncSession, connector: SearchSourceConnector
) -> None:
    """Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
    try:
        connector.config = {**connector.config, "auth_expired": True}
        flag_modified(connector, "config")
        await session.commit()
        await session.refresh(connector)
    except Exception:
        logger.warning(
            f"Failed to persist auth_expired flag for connector {connector.id}",
            exc_info=True,
        )


async def refresh_linear_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:

@@ -375,6 +495,7 @@ async def refresh_linear_token(
        ) from e

    if not refresh_token:
        await _mark_connector_auth_expired(session, connector)
        raise HTTPException(
            status_code=400,
            detail="No refresh token available. Please re-authenticate.",

@@ -417,6 +538,7 @@ async def refresh_linear_token(
            or "expired" in error_lower
            or "revoked" in error_lower
        ):
            await _mark_connector_auth_expired(session, connector)
            raise HTTPException(
                status_code=401,
                detail="Linear authentication failed. Please re-authenticate.",

@@ -453,10 +575,16 @@ async def refresh_linear_token(
        credentials.expires_at = expires_at
        credentials.scope = token_json.get("scope")

        # Update connector config with encrypted tokens
        # Update connector config with encrypted tokens, preserving non-credential fields
        credentials_dict = credentials.to_dict()
        credentials_dict["_token_encrypted"] = True
        if connector.config.get("organization_name"):
            credentials_dict["organization_name"] = connector.config[
                "organization_name"
            ]
        credentials_dict.pop("auth_expired", None)
        connector.config = credentials_dict
        flag_modified(connector, "config")
        await session.commit()
        await session.refresh(connector)
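The block above rebuilds connector.config from fresh credentials while carrying over non-credential keys and dropping the stale auth_expired flag. As a standalone pattern (sketch; key names taken from the diff, helper name ours):

# Sketch: merge fresh OAuth credentials into a connector config while
# preserving non-credential fields, as refresh_linear_token does above.
def merge_refreshed_config(old: dict, fresh: dict, keep: tuple[str, ...] = ("organization_name",)) -> dict:
    merged = {**fresh, "_token_encrypted": True}
    for key in keep:
        if old.get(key):
            merged[key] = old[key]
    merged.pop("auth_expired", None)  # a fresh token is by definition not expired
    return merged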

@@ -1,7 +1,7 @@
"""
Notifications API routes.
These endpoints allow marking notifications as read and fetching older notifications.
Electric SQL automatically syncs the changes to all connected clients for recent items.
Zero automatically syncs the changes to all connected clients for recent items.
For older items (beyond the sync window), use the list endpoint.
"""


@@ -267,7 +267,7 @@ async def get_unread_count(

    This allows the frontend to calculate:
    - older_unread = total_unread - recent_unread (static until reconciliation)
    - Display count = older_unread + live_recent_count (from Electric SQL)
    - Display count = older_unread + live_recent_count (from Zero)
    """
    # Calculate cutoff date for sync window
    cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
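In code, the split the docstring describes is one subtraction plus one addition; a sketch of the client-side arithmetic (hypothetical helper, names ours):

# Sketch of the display-count arithmetic from the docstring above.
# total_unread and recent_unread_at_fetch come from this endpoint once;
# live_recent_count is streamed by Zero for the sync window.
def display_unread(total_unread: int, recent_unread_at_fetch: int, live_recent_count: int) -> int:
    older_unread = total_unread - recent_unread_at_fetch
    return older_unread + live_recent_count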

@@ -344,7 +344,7 @@ async def list_notifications(
    List notifications for the current user with pagination.

    This endpoint is used as a fallback for older notifications that are
    outside the Electric SQL sync window (2 weeks).
    outside the Zero sync window (2 weeks).

    Use `before_date` to paginate through older notifications efficiently.
    """

@@ -487,7 +487,7 @@ async def mark_notification_as_read(
    """
    Mark a single notification as read.

    Electric SQL will automatically sync this change to all connected clients.
    Zero will automatically sync this change to all connected clients.
    """
    # Verify the notification belongs to the user
    result = await session.execute(

@@ -528,7 +528,7 @@ async def mark_all_notifications_as_read(
    """
    Mark all notifications as read for the current user.

    Electric SQL will automatically sync these changes to all connected clients.
    Zero will automatically sync these changes to all connected clients.
    """
    # Update all unread notifications for the user
    result = await session.execute(

@@ -12,8 +12,10 @@ import httpx
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
|
|
@@ -124,6 +126,70 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
    ) from e


+@router.get("/auth/notion/connector/reauth")
+async def reauth_notion(
+    space_id: int,
+    connector_id: int,
+    return_url: str | None = None,
+    user: User = Depends(current_active_user),
+    session: AsyncSession = Depends(get_async_session),
+):
+    """Initiate Notion re-authentication for an existing connector."""
+    try:
+        result = await session.execute(
+            select(SearchSourceConnector).filter(
+                SearchSourceConnector.id == connector_id,
+                SearchSourceConnector.user_id == user.id,
+                SearchSourceConnector.search_space_id == space_id,
+                SearchSourceConnector.connector_type
+                == SearchSourceConnectorType.NOTION_CONNECTOR,
+            )
+        )
+        connector = result.scalars().first()
+        if not connector:
+            raise HTTPException(
+                status_code=404,
+                detail="Notion connector not found or access denied",
+            )
+
+        if not config.NOTION_CLIENT_ID:
+            raise HTTPException(status_code=500, detail="Notion OAuth not configured.")
+        if not config.SECRET_KEY:
+            raise HTTPException(
+                status_code=500, detail="SECRET_KEY not configured for OAuth security."
+            )
+
+        state_manager = get_state_manager()
+        extra: dict = {"connector_id": connector_id}
+        if return_url and return_url.startswith("/"):
+            extra["return_url"] = return_url
+        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
+
+        from urllib.parse import urlencode
+
+        auth_params = {
+            "client_id": config.NOTION_CLIENT_ID,
+            "response_type": "code",
+            "owner": "user",
+            "redirect_uri": config.NOTION_REDIRECT_URI,
+            "state": state_encoded,
+        }
+        auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
+
+        logger.info(
+            f"Initiating Notion re-auth for user {user.id}, connector {connector_id}"
+        )
+        return {"auth_url": auth_url}
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Failed to initiate Notion re-auth: {e!s}", exc_info=True)
+        raise HTTPException(
+            status_code=500, detail=f"Failed to initiate Notion re-auth: {e!s}"
+        ) from e
+
+
@router.get("/auth/notion/connector/callback")
async def notion_callback(
    request: Request,

@@ -163,7 +229,7 @@ async def notion_callback(
    # Redirect to frontend with error parameter
    if space_id:
        return RedirectResponse(
-           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied"
+           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=notion_oauth_denied"
        )
    else:
        return RedirectResponse(
@@ -266,6 +332,42 @@ async def notion_callback(
            "_token_encrypted": True,
        }

+       reauth_connector_id = data.get("connector_id")
+       reauth_return_url = data.get("return_url")
+
+       if reauth_connector_id:
+           result = await session.execute(
+               select(SearchSourceConnector).filter(
+                   SearchSourceConnector.id == reauth_connector_id,
+                   SearchSourceConnector.user_id == user_id,
+                   SearchSourceConnector.search_space_id == space_id,
+                   SearchSourceConnector.connector_type
+                   == SearchSourceConnectorType.NOTION_CONNECTOR,
+               )
+           )
+           db_connector = result.scalars().first()
+           if not db_connector:
+               raise HTTPException(
+                   status_code=404,
+                   detail="Connector not found or access denied during re-auth",
+               )
+
+           db_connector.config = connector_config
+           flag_modified(db_connector, "config")
+           await session.commit()
+           await session.refresh(db_connector)
+
+           logger.info(
+               f"Re-authenticated Notion connector {db_connector.id} for user {user_id}"
+           )
+           if reauth_return_url and reauth_return_url.startswith("/"):
+               return RedirectResponse(
+                   url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
+               )
+           return RedirectResponse(
+               url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={db_connector.id}"
+           )
+
        # Extract unique identifier from connector credentials
        connector_identifier = extract_identifier_from_credentials(
            SearchSourceConnectorType.NOTION_CONNECTOR, connector_config
@@ -284,7 +386,7 @@ async def notion_callback(
            f"Duplicate Notion connector detected for user {user_id} with workspace {connector_identifier}"
        )
        return RedirectResponse(
-           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=notion-connector"
+           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=notion-connector"
        )

    # Generate a unique, user-friendly connector name

@@ -315,7 +417,7 @@ async def notion_callback(

    # Redirect to the frontend with success params
    return RedirectResponse(
-       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector&connectorId={new_connector.id}"
+       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={new_connector.id}"
    )

except ValidationError as e:
@@ -346,6 +448,22 @@ async def notion_callback(
    ) from e


+async def _mark_connector_auth_expired(
+    session: AsyncSession, connector: SearchSourceConnector
+) -> None:
+    """Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
+    try:
+        connector.config = {**connector.config, "auth_expired": True}
+        flag_modified(connector, "config")
+        await session.commit()
+        await session.refresh(connector)
+    except Exception:
+        logger.warning(
+            f"Failed to persist auth_expired flag for connector {connector.id}",
+            exc_info=True,
+        )
+
+
async def refresh_notion_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
@@ -379,6 +497,7 @@ async def refresh_notion_token(
        ) from e

    if not refresh_token:
+       await _mark_connector_auth_expired(session, connector)
        raise HTTPException(
            status_code=400,
            detail="No refresh token available. Please re-authenticate.",

@@ -421,6 +540,7 @@ async def refresh_notion_token(
        or "expired" in error_lower
        or "revoked" in error_lower
    ):
+       await _mark_connector_auth_expired(session, connector)
        raise HTTPException(
            status_code=401,
            detail="Notion authentication failed. Please re-authenticate.",

@@ -469,7 +589,9 @@ async def refresh_notion_token(
    # Update connector config with encrypted tokens
    credentials_dict = credentials.to_dict()
    credentials_dict["_token_encrypted"] = True
+   credentials_dict.pop("auth_expired", None)
    connector.config = credentials_dict
+   flag_modified(connector, "config")
    await session.commit()
    await session.refresh(connector)
@@ -72,6 +72,7 @@ from app.tasks.connector_indexers import (
    index_slack_messages,
)
from app.users import current_active_user
+from app.utils.connector_naming import ensure_unique_connector_name
from app.utils.indexing_locks import (
    acquire_connector_indexing_lock,
    release_connector_indexing_lock,

@@ -189,6 +190,12 @@ async def create_search_source_connector(
    # Prepare connector data
    connector_data = connector.model_dump()

+   # MCP connectors support multiple instances — ensure unique name
+   if connector.connector_type == SearchSourceConnectorType.MCP_CONNECTOR:
+       connector_data["name"] = await ensure_unique_connector_name(
+           session, connector_data["name"], search_space_id, user.id
+       )
+
    # Automatically set next_scheduled_at if periodic indexing is enabled
    if (
        connector.periodic_indexing_enabled
@@ -949,23 +956,46 @@ async def index_connector_content(
        index_google_drive_files_task,
    )

-   if not drive_items or not drive_items.has_items():
-       raise HTTPException(
-           status_code=400,
-           detail="Google Drive indexing requires drive_items body parameter with folders or files",
-       )
+   if drive_items and drive_items.has_items():
+       logger.info(
+           f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
+           f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
+       )
+       items_dict = drive_items.model_dump()
+   else:
+       # Quick Index / periodic sync: fall back to stored config
+       config = connector.config or {}
+       selected_folders = config.get("selected_folders", [])
+       selected_files = config.get("selected_files", [])
+       if not selected_folders and not selected_files:
+           raise HTTPException(
+               status_code=400,
+               detail="Google Drive indexing requires folders or files to be configured. "
+               "Please select folders/files to index.",
+           )
+       indexing_options = config.get(
+           "indexing_options",
+           {
+               "max_files_per_folder": 100,
+               "incremental_sync": True,
+               "include_subfolders": True,
+           },
+       )
+       items_dict = {
+           "folders": selected_folders,
+           "files": selected_files,
+           "indexing_options": indexing_options,
+       }
+       logger.info(
+           f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id} "
+           f"using existing config"
+       )

-   logger.info(
-       f"Triggering Google Drive indexing for connector {connector_id} into search space {search_space_id}, "
-       f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
-   )
-
    # Pass structured data to Celery task
    index_google_drive_files_task.delay(
        connector_id,
        search_space_id,
        str(user.id),
-       drive_items.model_dump(),  # Convert to dict for JSON serialization
+       items_dict,
    )
    response_message = "Google Drive indexing started in the background."
@@ -1061,7 +1091,7 @@ async def index_connector_content(
        == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
    ):
        from app.tasks.celery_tasks.connector_tasks import (
-           index_composio_connector_task,
+           index_google_drive_files_task,
        )

        # For Composio Google Drive, if drive_items is provided, update connector config

@@ -1095,34 +1125,72 @@ async def index_connector_content(
        else:
            logger.info(
                f"Triggering Composio Google Drive indexing for connector {connector_id} into search space {search_space_id} "
-               f"using existing config (from {indexing_from} to {indexing_to})"
+               f"using existing config"
            )

-       index_composio_connector_task.delay(
-           connector_id, search_space_id, str(user.id), indexing_from, indexing_to
+       # Extract config and build items_dict for index_google_drive_files_task
+       config = connector.config or {}
+       selected_folders = config.get("selected_folders", [])
+       selected_files = config.get("selected_files", [])
+       if not selected_folders and not selected_files:
+           raise HTTPException(
+               status_code=400,
+               detail="Composio Google Drive indexing requires folders or files to be configured. "
+               "Please select folders/files to index.",
+           )
+       indexing_options = config.get(
+           "indexing_options",
+           {
+               "max_files_per_folder": 100,
+               "incremental_sync": True,
+               "include_subfolders": True,
+           },
+       )
+       items_dict = {
+           "folders": selected_folders,
+           "files": selected_files,
+           "indexing_options": indexing_options,
+       }
+       index_google_drive_files_task.delay(
+           connector_id, search_space_id, str(user.id), items_dict
        )
        response_message = (
            "Composio Google Drive indexing started in the background."
        )

-   elif connector.connector_type in [
-       SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
-       SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
-   ]:
+   elif (
+       connector.connector_type
+       == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
+   ):
        from app.tasks.celery_tasks.connector_tasks import (
-           index_composio_connector_task,
+           index_google_gmail_messages_task,
        )

-       # For Composio Gmail and Calendar, use the same date calculation logic as normal connectors
-       # This ensures consistent behavior and uses last_indexed_at to reduce API calls
-       # (includes special case: if indexed today, go back 1 day to avoid missing data)
        logger.info(
-           f"Triggering Composio connector indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
+           f"Triggering Composio Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
        )
-       index_composio_connector_task.delay(
+       index_google_gmail_messages_task.delay(
            connector_id, search_space_id, str(user.id), indexing_from, indexing_to
        )
-       response_message = "Composio connector indexing started in the background."
+       response_message = "Composio Gmail indexing started in the background."

+   elif (
+       connector.connector_type
+       == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
+   ):
+       from app.tasks.celery_tasks.connector_tasks import (
+           index_google_calendar_events_task,
+       )
+
+       logger.info(
+           f"Triggering Composio Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
+       )
+       index_google_calendar_events_task.delay(
+           connector_id, search_space_id, str(user.id), indexing_from, indexing_to
+       )
+       response_message = (
+           "Composio Google Calendar indexing started in the background."
+       )
+
    else:
        raise HTTPException(
@@ -1229,6 +1297,48 @@ async def run_slack_indexing(
    )


+_AUTH_ERROR_PATTERNS = (
+    "failed to refresh linear oauth",
+    "failed to refresh your notion connection",
+    "failed to refresh notion token",
+    "authentication failed",
+    "auth_expired",
+    "token has been expired or revoked",
+    "invalid_grant",
+)
+
+
+def _is_auth_error(error_message: str) -> bool:
+    """Check if an error message indicates an OAuth token expiry failure."""
+    if not error_message:
+        return False
+    lower = error_message.lower()
+    return any(pattern in lower for pattern in _AUTH_ERROR_PATTERNS)
+
+
+async def _persist_auth_expired(session: AsyncSession, connector_id: int) -> None:
+    """Flag a connector as auth_expired so the frontend shows a re-auth prompt."""
+    from sqlalchemy.orm.attributes import flag_modified
+
+    try:
+        result = await session.execute(
+            select(SearchSourceConnector).where(
+                SearchSourceConnector.id == connector_id
+            )
+        )
+        connector = result.scalar_one_or_none()
+        if connector and not connector.config.get("auth_expired"):
+            connector.config = {**connector.config, "auth_expired": True}
+            flag_modified(connector, "config")
+            await session.commit()
+            logger.info(f"Marked connector {connector_id} as auth_expired")
+    except Exception:
+        logger.warning(
+            f"Failed to persist auth_expired for connector {connector_id}",
+            exc_info=True,
+        )
+
+
async def _run_indexing_with_notifications(
    session: AsyncSession,
    connector_id: int,
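For illustration, a minimal sketch of how the `_is_auth_error` substring check above behaves; the match is effectively case-insensitive because the input is lowercased first:

    # illustration only, not part of the committed code
    assert _is_auth_error("Token has been EXPIRED or revoked") is True
    assert _is_auth_error("rate limit exceeded") is False
    assert _is_auth_error("") is False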
@@ -1433,7 +1543,7 @@ async def _run_indexing_with_notifications(
            )
            await (
                session.commit()
-           )  # Commit to ensure Electric SQL syncs the notification update
+           )  # Commit to ensure Zero syncs the notification update
        elif documents_processed > 0:
            # Update notification to storing stage
            if notification:

@@ -1460,7 +1570,7 @@ async def _run_indexing_with_notifications(
            )
            await (
                session.commit()
-           )  # Commit to ensure Electric SQL syncs the notification update
+           )  # Commit to ensure Zero syncs the notification update
        else:
            # No new documents processed - check if this is an error or just no changes
            if error_or_warning:

@@ -1486,7 +1596,7 @@ async def _run_indexing_with_notifications(
            if is_duplicate_warning or is_empty_result or is_info_warning:
                # These are success cases - sync worked, just found nothing new
                logger.info(f"Indexing completed successfully: {error_or_warning}")
-               # Still update timestamp so ElectricSQL syncs and clears "Syncing" UI
+               # Still update timestamp so Zero syncs and clears "Syncing" UI
                if update_timestamp_func:
                    await update_timestamp_func(session, connector_id)
                    await session.commit()  # Commit timestamp update

@@ -1509,10 +1619,12 @@ async def _run_indexing_with_notifications(
                )
                await (
                    session.commit()
-               )  # Commit to ensure Electric SQL syncs the notification update
+               )  # Commit to ensure Zero syncs the notification update
            else:
                # Actual failure
                logger.error(f"Indexing failed: {error_or_warning}")
+               if _is_auth_error(str(error_or_warning)):
+                   await _persist_auth_expired(session, connector_id)
                if notification:
                    # Refresh notification to ensure it's not stale after indexing function commits
                    await session.refresh(notification)

@@ -1525,13 +1637,13 @@ async def _run_indexing_with_notifications(
                )
                await (
                    session.commit()
-               )  # Commit to ensure Electric SQL syncs the notification update
+               )  # Commit to ensure Zero syncs the notification update
        else:
            # Success - just no new documents to index (all skipped/unchanged)
            logger.info(
                "Indexing completed: No new documents to process (all up to date)"
            )
-           # Still update timestamp so ElectricSQL syncs and clears "Syncing" UI
+           # Still update timestamp so Zero syncs and clears "Syncing" UI
            if update_timestamp_func:
                await update_timestamp_func(session, connector_id)
                await session.commit()  # Commit timestamp update

@@ -1547,7 +1659,7 @@ async def _run_indexing_with_notifications(
            )
            await (
                session.commit()
-           )  # Commit to ensure Electric SQL syncs the notification update
+           )  # Commit to ensure Zero syncs the notification update
    except SoftTimeLimitExceeded:
        # Celery soft time limit was reached - task is about to be killed
        # Gracefully save progress and mark as interrupted

@@ -1577,6 +1689,9 @@ async def _run_indexing_with_notifications(
    except Exception as e:
        logger.error(f"Error in indexing task: {e!s}", exc_info=True)

+       if _is_auth_error(str(e)):
+           await _persist_auth_expired(session, connector_id)
+
        # Update notification on exception
        if notification:
            try:
@@ -2172,10 +2287,9 @@ async def run_google_gmail_indexing(
        end_date: str | None,
        update_last_indexed: bool,
        on_heartbeat_callback=None,
-   ) -> tuple[int, str | None]:
-       # Use a reasonable default for max_messages
+   ) -> tuple[int, int, str | None]:
        max_messages = 1000
-       indexed_count, error_message = await index_google_gmail_messages(
+       indexed_count, skipped_count, error_message = await index_google_gmail_messages(
            session=session,
            connector_id=connector_id,
            search_space_id=search_space_id,

@@ -2186,8 +2300,7 @@ async def run_google_gmail_indexing(
            max_messages=max_messages,
            on_heartbeat_callback=on_heartbeat_callback,
        )
-       # index_google_gmail_messages returns (int, str) but we need (int, str | None)
-       return indexed_count, error_message if error_message else None
+       return indexed_count, skipped_count, error_message if error_message else None

    await _run_indexing_with_notifications(
        session=session,
@@ -2223,6 +2336,7 @@ async def run_google_drive_indexing(
    items = GoogleDriveIndexRequest(**items_dict)
    indexing_options = items.indexing_options
    total_indexed = 0
+   total_skipped = 0
    errors = []

    # Get connector info for notification

@@ -2260,7 +2374,11 @@ async def run_google_drive_indexing(
    # Index each folder with indexing options
    for folder in items.folders:
        try:
-           indexed_count, error_message = await index_google_drive_files(
+           (
+               indexed_count,
+               skipped_count,
+               error_message,
+           ) = await index_google_drive_files(
                session,
                connector_id,
                search_space_id,

@@ -2272,6 +2390,7 @@ async def run_google_drive_indexing(
                max_files=indexing_options.max_files_per_folder,
                include_subfolders=indexing_options.include_subfolders,
            )
+           total_skipped += skipped_count
            if error_message:
                errors.append(f"Folder '{folder.name}': {error_message}")
            else:

@@ -2312,9 +2431,15 @@ async def run_google_drive_indexing(
        logger.error(
            f"Google Drive indexing completed with errors for connector {connector_id}: {error_message}"
        )
+       if _is_auth_error(error_message):
+           await _persist_auth_expired(session, connector_id)
+           error_message = (
+               "Google Drive authentication expired. Please re-authenticate."
+           )
    else:
        # Update notification to storing stage
        if notification:
            await session.refresh(notification)
            await NotificationService.connector_indexing.notify_indexing_progress(
                session=session,
                notification=notification,

@@ -2338,6 +2463,7 @@ async def run_google_drive_indexing(
        notification=notification,
        indexed_count=total_indexed,
        error_message=error_message,
+       skipped_count=total_skipped,
    )

    except Exception as e:
@@ -2650,7 +2776,7 @@ async def run_composio_indexing(
    Run Composio connector indexing with real-time notifications.

    This wraps the Composio indexer with the notification system so that
-   Electric SQL can sync indexing progress to the frontend in real-time.
+   Zero can sync indexing progress to the frontend in real-time.

    Args:
        session: Database session

@@ -2715,9 +2841,14 @@ async def create_mcp_connector(
            "You don't have permission to create connectors in this search space",
        )

+   # Ensure unique name across MCP connectors in this search space
+   unique_name = await ensure_unique_connector_name(
+       session, connector_data.name, search_space_id, user.id
+   )
+
    # Create the connector with single server config
    db_connector = SearchSourceConnector(
-       name=connector_data.name,
+       name=unique_name,
        connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
        is_indexable=False,  # MCP connectors are not indexable
        config={"server_config": connector_data.server_config.model_dump()},

@@ -3136,6 +3267,12 @@ async def get_drive_picker_token(
        raise
    except Exception as e:
        logger.error(f"Failed to get Drive picker token: {e!s}", exc_info=True)
+       if _is_auth_error(str(e)):
+           await _persist_auth_expired(session, connector_id)
+           raise HTTPException(
+               status_code=400,
+               detail="Google Drive authentication expired. Please re-authenticate.",
+           ) from e
        raise HTTPException(
            status_code=500,
            detail="Failed to retrieve access token. Check server logs for details.",
@@ -166,7 +166,7 @@ async def slack_callback(
    # Redirect to frontend with error parameter
    if space_id:
        return RedirectResponse(
-           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=slack_oauth_denied"
+           url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=slack_oauth_denied"
        )
    else:
        return RedirectResponse(

@@ -296,7 +296,7 @@ async def slack_callback(
        f"Duplicate Slack connector detected for user {user_id} with workspace {connector_identifier}"
    )
    return RedirectResponse(
-       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=slack-connector"
+       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=slack-connector"
    )

    # Generate a unique, user-friendly connector name

@@ -328,7 +328,7 @@ async def slack_callback(

    # Redirect to the frontend with success params
    return RedirectResponse(
-       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=slack-connector&connectorId={new_connector.id}"
+       url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=slack-connector&connectorId={new_connector.id}"
    )

except ValidationError as e:
@@ -456,7 +456,7 @@ async def create_comment(
    thread = message.thread
    comment = ChatComment(
        message_id=message_id,
-       thread_id=thread.id,  # Denormalized for efficient Electric subscriptions
+       thread_id=thread.id,  # Denormalized for efficient per-thread sync
        author_id=user.id,
        content=content,
    )

@@ -569,7 +569,7 @@ async def create_reply(
    thread = parent_comment.message.thread
    reply = ChatComment(
        message_id=parent_comment.message_id,
-       thread_id=thread.id,  # Denormalized for efficient Electric subscriptions
+       thread_id=thread.id,  # Denormalized for efficient per-thread sync
        parent_id=comment_id,
        author_id=user.id,
        content=content,
@@ -36,32 +36,14 @@ TOOLKIT_TO_CONNECTOR_TYPE = {
}

# Mapping of toolkit IDs to document types
-TOOLKIT_TO_DOCUMENT_TYPE = {
-    "googledrive": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
-    "gmail": "COMPOSIO_GMAIL_CONNECTOR",
-    "googlecalendar": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
-}
+# Google Drive, Gmail, Calendar use unified native indexers - not in this registry
+TOOLKIT_TO_DOCUMENT_TYPE: dict[str, str] = {}

# Mapping of toolkit IDs to their indexer functions
-# Format: toolkit_id -> (module_path, function_name, supports_date_filter)
-# supports_date_filter: True if the indexer accepts start_date/end_date params
-TOOLKIT_TO_INDEXER = {
-    "googledrive": (
-        "app.connectors.composio_google_drive_connector",
-        "index_composio_google_drive",
-        False,  # Google Drive doesn't use date filtering
-    ),
-    "gmail": (
-        "app.connectors.composio_gmail_connector",
-        "index_composio_gmail",
-        True,  # Gmail uses date filtering
-    ),
-    "googlecalendar": (
-        "app.connectors.composio_google_calendar_connector",
-        "index_composio_google_calendar",
-        True,  # Calendar uses date filtering
-    ),
-}
+# Google Drive, Gmail, Calendar use unified native indexers - not in this registry
+TOOLKIT_TO_INDEXER: dict[str, tuple[str, str, bool]] = {}


class ComposioService:
@@ -247,6 +229,68 @@ class ComposioService:
            )
            return False

+   def refresh_connected_account(
+       self,
+       connected_account_id: str,
+       redirect_url: str | None = None,
+   ) -> dict[str, Any]:
+       """
+       Refresh an expired Composio connected account.
+
+       For OAuth flows this returns a redirect_url the user must visit to
+       re-authorise. The same connected_account_id stays valid afterwards.
+
+       Args:
+           connected_account_id: The Composio connected account nanoid.
+           redirect_url: Where Composio should redirect after re-auth.
+
+       Returns:
+           Dict with id, status, and redirect_url (None when no redirect needed).
+       """
+       kwargs: dict[str, Any] = {}
+       if redirect_url is not None:
+           kwargs["body_redirect_url"] = redirect_url
+       result = self.client.connected_accounts.refresh(
+           nanoid=connected_account_id,
+           **kwargs,
+       )
+       return {
+           "id": result.id,
+           "status": result.status,
+           "redirect_url": result.redirect_url,
+       }
+
+   def wait_for_connection(
+       self,
+       connected_account_id: str,
+       timeout: float = 30.0,
+   ) -> str:
+       """
+       Poll Composio until the connected account reaches ACTIVE status.
+
+       Must be called after refresh() / initiate() to ensure Composio has
+       finished exchanging the authorization code for valid tokens.
+
+       Returns:
+           The final account status string (should be "ACTIVE").
+
+       Raises:
+           TimeoutError: If the account does not become ACTIVE within *timeout*.
+       """
+       try:
+           account = self.client.connected_accounts.wait_for_connection(
+               id=connected_account_id,
+               timeout=timeout,
+           )
+           status = getattr(account, "status", "UNKNOWN")
+           logger.info(f"Composio account {connected_account_id} is now {status}")
+           return status
+       except Exception as e:
+           logger.error(
+               f"Timeout/error waiting for Composio account {connected_account_id}: {e!s}"
+           )
+           raise
+
    def get_access_token(self, connected_account_id: str) -> str:
        """Retrieve the raw OAuth access token for a Composio connected account."""
        account = self.client.connected_accounts.get(nanoid=connected_account_id)

@@ -258,6 +302,12 @@ class ComposioService:
        access_token = getattr(token, "access_token", None)
        if not access_token:
            raise ValueError(f"No access_token in state.val for {connected_account_id}")
+       if len(access_token) < 20:
+           raise ValueError(
+               f"Composio returned a masked access_token ({len(access_token)} chars) "
+               f"for account {connected_account_id}. Disable 'Mask Connected Account "
+               f"Secrets' in Composio dashboard: Settings → Project Settings."
+           )
        return access_token

    async def execute_tool(
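A minimal sketch of the re-auth flow the two methods above enable, assuming an already-constructed ComposioService instance named `service`; the account id and URLs are placeholders:

    # illustration only, not part of the committed code
    refreshed = service.refresh_connected_account(
        "ca_example123",  # placeholder connected account nanoid
        redirect_url="https://app.example.com/composio/callback",  # placeholder
    )
    if refreshed["redirect_url"]:
        # The user must visit this URL to re-authorise before tokens become valid
        print("Visit:", refreshed["redirect_url"])
    status = service.wait_for_connection("ca_example123", timeout=30.0)
    assert status == "ACTIVE"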

13
surfsense_backend/app/services/confluence/__init__.py
Normal file
@@ -0,0 +1,13 @@
from app.services.confluence.kb_sync_service import ConfluenceKBSyncService
from app.services.confluence.tool_metadata_service import (
    ConfluencePage,
    ConfluenceToolMetadataService,
    ConfluenceWorkspace,
)

__all__ = [
    "ConfluenceKBSyncService",
    "ConfluencePage",
    "ConfluenceToolMetadataService",
    "ConfluenceWorkspace",
]

240
surfsense_backend/app/services/confluence/kb_sync_service.py
Normal file
@@ -0,0 +1,240 @@
import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


class ConfluenceKBSyncService:
    """Syncs Confluence page documents to the knowledge base after HITL actions."""

    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        page_id: str,
        page_title: str,
        space_id: str,
        body_content: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                return {"status": "success"}

            indexable_content = (body_content or "").strip()
            if not indexable_content:
                indexable_content = f"Confluence Page: {page_title}"

            page_content = f"# {page_title}\n\n{indexable_content}"

            content_hash = generate_content_hash(page_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "page_title": page_title,
                "space_id": space_id,
                "document_type": "Confluence Page",
                "connector_type": "Confluence",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    page_content, user_llm, doc_metadata_for_summary
                )
            else:
                summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(page_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document = Document(
                title=page_title,
                document_type=DocumentType.CONFLUENCE_CONNECTOR,
                document_metadata={
                    "page_id": page_id,
                    "page_title": page_title,
                    "space_id": space_id,
                    "comment_count": 0,
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, page=%s",
                document.id,
                page_title,
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for page %s: %s",
                page_title,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def sync_after_update(
        self,
        document_id: int,
        page_id: str,
        user_id: str,
        search_space_id: int,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            document = await self.db_session.get(Document, document_id)
            if not document:
                return {"status": "not_indexed"}

            connector_id = document.connector_id
            if not connector_id:
                return {"status": "error", "message": "Document has no connector_id"}

            client = ConfluenceHistoryConnector(
                session=self.db_session, connector_id=connector_id
            )
            page_data = await client.get_page(page_id)
            await client.close()

            page_title = page_data.get("title", "")
            body_obj = page_data.get("body", {})
            body_content = ""
            if isinstance(body_obj, dict):
                storage = body_obj.get("storage", {})
                if isinstance(storage, dict):
                    body_content = storage.get("value", "")

            page_content = f"# {page_title}\n\n{body_content}"

            if not page_content.strip():
                return {"status": "error", "message": "Page produced empty content"}

            space_id = (document.document_metadata or {}).get("space_id", "")

            user_llm = await get_user_long_context_llm(
                self.db_session, user_id, search_space_id, disable_streaming=True
            )

            if user_llm:
                doc_meta = {
                    "page_title": page_title,
                    "space_id": space_id,
                    "document_type": "Confluence Page",
                    "connector_type": "Confluence",
                }
                summary_content, summary_embedding = await generate_document_summary(
                    page_content, user_llm, doc_meta
                )
            else:
                summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(page_content)

            document.title = page_title
            document.content = summary_content
            document.content_hash = generate_content_hash(page_content, search_space_id)
            document.embedding = summary_embedding

            from sqlalchemy.orm.attributes import flag_modified

            document.document_metadata = {
                **(document.document_metadata or {}),
                "page_id": page_id,
                "page_title": page_title,
                "space_id": space_id,
                "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "connector_id": connector_id,
            }
            flag_modified(document, "document_metadata")
            await safe_set_chunks(self.db_session, document, chunks)
            document.updated_at = get_current_timestamp()

            await self.db_session.commit()

            logger.info(
                "KB sync successful for document %s (%s)",
                document_id,
                page_title,
            )
            return {"status": "success"}

        except Exception as e:
            logger.error(
                "KB sync failed for document %s: %s", document_id, e, exc_info=True
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

314
surfsense_backend/app/services/confluence/tool_metadata_service.py
Normal file
@@ -0,0 +1,314 @@
import logging
from dataclasses import dataclass

from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)

logger = logging.getLogger(__name__)


@dataclass
class ConfluenceWorkspace:
    """Represents a Confluence connector as a workspace for tool context."""

    id: int
    name: str
    base_url: str

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "ConfluenceWorkspace":
        return cls(
            id=connector.id,
            name=connector.name,
            base_url=connector.config.get("base_url", ""),
        )

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "name": self.name,
            "base_url": self.base_url,
        }


@dataclass
class ConfluencePage:
    """Represents an indexed Confluence page resolved from the knowledge base."""

    page_id: str
    page_title: str
    space_id: str
    connector_id: int
    document_id: int
    indexed_at: str | None

    @classmethod
    def from_document(cls, document: Document) -> "ConfluencePage":
        meta = document.document_metadata or {}
        return cls(
            page_id=meta.get("page_id", ""),
            page_title=meta.get("page_title", document.title),
            space_id=meta.get("space_id", ""),
            connector_id=document.connector_id,
            document_id=document.id,
            indexed_at=meta.get("indexed_at"),
        )

    def to_dict(self) -> dict:
        return {
            "page_id": self.page_id,
            "page_title": self.page_title,
            "space_id": self.space_id,
            "connector_id": self.connector_id,
            "document_id": self.document_id,
            "indexed_at": self.indexed_at,
        }


class ConfluenceToolMetadataService:
    """Builds interrupt context for Confluence HITL tools."""

    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
        """Check if the Confluence connector auth is still valid.

        Returns True if auth is expired/invalid, False if healthy.
        """
        try:
            client = ConfluenceHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            await client._get_valid_token()
            await client.close()
            return False
        except Exception as e:
            logger.warning(
                "Confluence connector %s health check failed: %s", connector.id, e
            )
            try:
                connector.config = {**connector.config, "auth_expired": True}
                flag_modified(connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(connector)
            except Exception:
                logger.warning(
                    "Failed to persist auth_expired for connector %s",
                    connector.id,
                    exc_info=True,
                )
            return True

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        """Return context needed to create a new Confluence page.

        Fetches all connected accounts, and for the first healthy one fetches spaces.
        """
        connectors = await self._get_all_confluence_connectors(search_space_id, user_id)
        if not connectors:
            return {"error": "No Confluence account connected"}

        accounts = []
        spaces = []
        fetched_context = False

        for connector in connectors:
            auth_expired = await self._check_account_health(connector)
            workspace = ConfluenceWorkspace.from_connector(connector)
            accounts.append(
                {
                    **workspace.to_dict(),
                    "auth_expired": auth_expired,
                }
            )

            if not auth_expired and not fetched_context:
                try:
                    client = ConfluenceHistoryConnector(
                        session=self._db_session, connector_id=connector.id
                    )
                    raw_spaces = await client.get_all_spaces()
                    spaces = [
                        {"id": s.get("id"), "key": s.get("key"), "name": s.get("name")}
                        for s in raw_spaces
                    ]
                    await client.close()
                    fetched_context = True
                except Exception as e:
                    logger.warning(
                        "Failed to fetch Confluence spaces for connector %s: %s",
                        connector.id,
                        e,
                    )

        return {
            "accounts": accounts,
            "spaces": spaces,
        }

    async def get_update_context(
        self, search_space_id: int, user_id: str, page_ref: str
    ) -> dict:
        """Return context needed to update an indexed Confluence page.

        Resolves the page from KB, then fetches current content and version from API.
        """
        document = await self._resolve_page(search_space_id, user_id, page_ref)
        if not document:
            return {
                "error": f"Page '{page_ref}' not found in your synced Confluence pages. "
                "Please make sure the page is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = await self._check_account_health(connector)
        if auth_expired:
            return {
                "error": "Confluence authentication has expired. Please re-authenticate.",
                "auth_expired": True,
                "connector_id": connector.id,
            }

        workspace = ConfluenceWorkspace.from_connector(connector)
        page = ConfluencePage.from_document(document)

        try:
            client = ConfluenceHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            page_data = await client.get_page(page.page_id)
            await client.close()
        except Exception as e:
            error_str = str(e).lower()
            if (
                "401" in error_str
                or "403" in error_str
                or "authentication" in error_str
            ):
                return {
                    "error": f"Failed to fetch Confluence page: {e!s}",
                    "auth_expired": True,
                    "connector_id": connector.id,
                }
            return {"error": f"Failed to fetch Confluence page: {e!s}"}

        body_storage = ""
        body_obj = page_data.get("body", {})
        if isinstance(body_obj, dict):
            storage = body_obj.get("storage", {})
            if isinstance(storage, dict):
                body_storage = storage.get("value", "")

        version_obj = page_data.get("version", {})
        version_number = (
            version_obj.get("number", 1) if isinstance(version_obj, dict) else 1
        )

        return {
            "account": {**workspace.to_dict(), "auth_expired": False},
            "page": {
                "page_id": page.page_id,
                "page_title": page_data.get("title", page.page_title),
                "space_id": page.space_id,
                "body": body_storage,
                "version": version_number,
                "document_id": page.document_id,
                "indexed_at": page.indexed_at,
            },
        }

    async def get_deletion_context(
        self, search_space_id: int, user_id: str, page_ref: str
    ) -> dict:
        """Return context needed to delete a Confluence page (KB metadata only)."""
        document = await self._resolve_page(search_space_id, user_id, page_ref)
        if not document:
            return {
                "error": f"Page '{page_ref}' not found in your synced Confluence pages. "
                "Please make sure the page is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = connector.config.get("auth_expired", False)
        workspace = ConfluenceWorkspace.from_connector(connector)
        page = ConfluencePage.from_document(document)

        return {
            "account": {**workspace.to_dict(), "auth_expired": auth_expired},
            "page": page.to_dict(),
        }

    async def _resolve_page(
        self, search_space_id: int, user_id: str, page_ref: str
    ) -> Document | None:
        """Resolve a page from KB: page_title -> document.title."""
        ref_lower = page_ref.lower()

        result = await self._db_session.execute(
            select(Document)
            .join(
                SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type == DocumentType.CONFLUENCE_CONNECTOR,
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(Document.document_metadata.op("->>")("page_title"))
                        == ref_lower,
                        func.lower(Document.title) == ref_lower,
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        return result.scalars().first()

    async def _get_all_confluence_connectors(
        self, search_space_id: int, user_id: str
    ) -> list[SearchSourceConnector]:
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
        )
        return result.scalars().all()

    async def _get_connector_for_document(
        self, document: Document, user_id: str
    ) -> SearchSourceConnector | None:
        if not document.connector_id:
            return None
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.id == document.connector_id,
                    SearchSourceConnector.user_id == user_id,
                )
            )
        )
        return result.scalars().first()
@@ -11,6 +11,7 @@ from sqlalchemy.future import select
from tavily import TavilyClient

from app.db import (
+   NATIVE_TO_LEGACY_DOCTYPE,
    Chunk,
    Document,
    SearchSourceConnector,

@@ -219,7 +220,7 @@ class ConnectorService:
        self,
        query_text: str,
        search_space_id: int,
-       document_type: str,
+       document_type: str | list[str],
        top_k: int = 20,
        start_date: datetime | None = None,
        end_date: datetime | None = None,

@@ -241,7 +242,8 @@ class ConnectorService:
        Args:
            query_text: The search query text
            search_space_id: The search space ID to search within
-           document_type: Document type to filter (e.g., "FILE", "CRAWLED_URL")
+           document_type: Document type(s) to filter (e.g., "FILE", "CRAWLED_URL",
+               or a list for multi-type queries)
            top_k: Number of results to return
            start_date: Optional start date for filtering documents by updated_at
            end_date: Optional end date for filtering documents by updated_at
@@ -254,6 +256,16 @@ class ConnectorService:
        perf = get_perf_logger()
        t0 = time.perf_counter()

+       # Expand native Google types to include legacy Composio equivalents
+       # so old documents remain searchable until re-indexed.
+       if isinstance(document_type, str) and document_type in NATIVE_TO_LEGACY_DOCTYPE:
+           resolved_type: str | list[str] = [
+               document_type,
+               NATIVE_TO_LEGACY_DOCTYPE[document_type],
+           ]
+       else:
+           resolved_type = document_type

        # RRF constant
        k = 60
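For reference, a generic sketch of the reciprocal rank fusion this constant feeds (the actual fusion happens in the retriever; this is not the committed implementation): each document's fused score is the sum of 1 / (k + rank) across the ranked lists it appears in, and k = 60 damps the dominance of top-ranked hits:

    # generic RRF illustration, not part of the committed code
    def rrf_scores(result_lists: list[list[str]], k: int = 60) -> dict[str, float]:
        scores: dict[str, float] = {}
        for results in result_lists:  # e.g. chunk-level and document-level lists
            for rank, doc_id in enumerate(results, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
        return scores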
@@ -276,7 +288,7 @@ class ConnectorService:
            "query_text": query_text,
            "top_k": retriever_top_k,
            "search_space_id": search_space_id,
-           "document_type": document_type,
+           "document_type": resolved_type,
            "start_date": start_date,
            "end_date": end_date,
            "query_embedding": query_embedding,
@@ -2746,299 +2758,6 @@ class ConnectorService:

        return result_object, obsidian_docs

-   # =========================================================================
-   # Composio Connector Search Methods
-   # =========================================================================
-
-   async def search_composio_google_drive(
-       self,
-       user_query: str,
-       search_space_id: int,
-       top_k: int = 20,
-       start_date: datetime | None = None,
-       end_date: datetime | None = None,
-   ) -> tuple:
-       """
-       Search for Composio Google Drive files and return both the source information
-       and langchain documents.
-
-       Uses combined chunk-level and document-level hybrid search with RRF fusion.
-
-       Args:
-           user_query: The user's query
-           search_space_id: The search space ID to search in
-           top_k: Maximum number of results to return
-           start_date: Optional start date for filtering documents by updated_at
-           end_date: Optional end date for filtering documents by updated_at
-
-       Returns:
-           tuple: (sources_info, langchain_documents)
-       """
-       composio_drive_docs = await self._combined_rrf_search(
-           query_text=user_query,
-           search_space_id=search_space_id,
-           document_type="COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
-           top_k=top_k,
-           start_date=start_date,
-           end_date=end_date,
-       )
-
-       # Early return if no results
-       if not composio_drive_docs:
-           return {
-               "id": 54,
-               "name": "Google Drive (Composio)",
-               "type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
-               "sources": [],
-           }, []
-
-       def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return (
-               doc_info.get("title")
-               or metadata.get("title")
-               or metadata.get("file_name")
-               or "Untitled Document"
-           )
-
-       def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return metadata.get("url") or metadata.get("web_view_link") or ""
-
-       def _description_fn(
-           chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> str:
-           description = self._chunk_preview(chunk.get("content", ""), limit=200)
-           info_parts = []
-           mime_type = metadata.get("mime_type")
-           modified_time = metadata.get("modified_time")
-           if mime_type:
-               info_parts.append(f"Type: {mime_type}")
-           if modified_time:
-               info_parts.append(f"Modified: {modified_time}")
-           if info_parts:
-               description = (description + " | " + " | ".join(info_parts)).strip(" |")
-           return description
-
-       def _extra_fields_fn(
-           _chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> dict[str, Any]:
-           return {
-               "mime_type": metadata.get("mime_type", ""),
-               "file_id": metadata.get("file_id", ""),
-               "modified_time": metadata.get("modified_time", ""),
-           }
-
-       sources_list = self._build_chunk_sources_from_documents(
-           composio_drive_docs,
-           title_fn=_title_fn,
-           url_fn=_url_fn,
-           description_fn=_description_fn,
-           extra_fields_fn=_extra_fields_fn,
-       )
-
-       # Create result object
-       result_object = {
-           "id": 54,
-           "name": "Google Drive (Composio)",
-           "type": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
-           "sources": sources_list,
-       }
-
-       return result_object, composio_drive_docs
-
-   async def search_composio_gmail(
-       self,
-       user_query: str,
-       search_space_id: int,
-       top_k: int = 20,
-       start_date: datetime | None = None,
-       end_date: datetime | None = None,
-   ) -> tuple:
-       """
-       Search for Composio Gmail messages and return both the source information
-       and langchain documents.
-
-       Uses combined chunk-level and document-level hybrid search with RRF fusion.
-
-       Args:
-           user_query: The user's query
-           search_space_id: The search space ID to search in
-           top_k: Maximum number of results to return
-           start_date: Optional start date for filtering documents by updated_at
-           end_date: Optional end date for filtering documents by updated_at
-
-       Returns:
-           tuple: (sources_info, langchain_documents)
-       """
-       composio_gmail_docs = await self._combined_rrf_search(
-           query_text=user_query,
-           search_space_id=search_space_id,
-           document_type="COMPOSIO_GMAIL_CONNECTOR",
-           top_k=top_k,
-           start_date=start_date,
-           end_date=end_date,
-       )
-
-       # Early return if no results
-       if not composio_gmail_docs:
-           return {
-               "id": 55,
-               "name": "Gmail (Composio)",
-               "type": "COMPOSIO_GMAIL_CONNECTOR",
-               "sources": [],
-           }, []
-
-       def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return (
-               doc_info.get("title")
-               or metadata.get("subject")
-               or metadata.get("title")
-               or "Untitled Email"
-           )
-
-       def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return metadata.get("url") or ""
-
-       def _description_fn(
-           chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> str:
-           description = self._chunk_preview(chunk.get("content", ""), limit=200)
-           info_parts = []
-           sender = metadata.get("from") or metadata.get("sender")
-           date = metadata.get("date") or metadata.get("received_at")
-           if sender:
-               info_parts.append(f"From: {sender}")
-           if date:
-               info_parts.append(f"Date: {date}")
-           if info_parts:
-               description = (description + " | " + " | ".join(info_parts)).strip(" |")
-           return description
-
-       def _extra_fields_fn(
-           _chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> dict[str, Any]:
-           return {
-               "message_id": metadata.get("message_id", ""),
-               "thread_id": metadata.get("thread_id", ""),
-               "from": metadata.get("from", ""),
-               "to": metadata.get("to", ""),
-               "date": metadata.get("date", ""),
-           }
-
-       sources_list = self._build_chunk_sources_from_documents(
-           composio_gmail_docs,
-           title_fn=_title_fn,
-           url_fn=_url_fn,
-           description_fn=_description_fn,
-           extra_fields_fn=_extra_fields_fn,
-       )
-
-       # Create result object
-       result_object = {
-           "id": 55,
-           "name": "Gmail (Composio)",
-           "type": "COMPOSIO_GMAIL_CONNECTOR",
-           "sources": sources_list,
-       }
-
-       return result_object, composio_gmail_docs
-
-   async def search_composio_google_calendar(
-       self,
-       user_query: str,
-       search_space_id: int,
-       top_k: int = 20,
-       start_date: datetime | None = None,
-       end_date: datetime | None = None,
-   ) -> tuple:
-       """
-       Search for Composio Google Calendar events and return both the source information
-       and langchain documents.
-
-       Uses combined chunk-level and document-level hybrid search with RRF fusion.
-
-       Args:
-           user_query: The user's query
-           search_space_id: The search space ID to search in
-           top_k: Maximum number of results to return
-           start_date: Optional start date for filtering documents by updated_at
-           end_date: Optional end date for filtering documents by updated_at
-
-       Returns:
-           tuple: (sources_info, langchain_documents)
-       """
-       composio_calendar_docs = await self._combined_rrf_search(
-           query_text=user_query,
-           search_space_id=search_space_id,
-           document_type="COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
-           top_k=top_k,
-           start_date=start_date,
-           end_date=end_date,
-       )
-
-       # Early return if no results
-       if not composio_calendar_docs:
-           return {
-               "id": 56,
-               "name": "Google Calendar (Composio)",
-               "type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
-               "sources": [],
-           }, []
-
-       def _title_fn(doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return (
-               doc_info.get("title")
-               or metadata.get("summary")
-               or metadata.get("title")
-               or "Untitled Event"
-           )
-
-       def _url_fn(_doc_info: dict[str, Any], metadata: dict[str, Any]) -> str:
-           return metadata.get("url") or metadata.get("html_link") or ""
-
-       def _description_fn(
-           chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> str:
-           description = self._chunk_preview(chunk.get("content", ""), limit=200)
-           info_parts = []
-           start_time = metadata.get("start_time") or metadata.get("start")
-           end_time = metadata.get("end_time") or metadata.get("end")
-           if start_time:
-               info_parts.append(f"Start: {start_time}")
-           if end_time:
-               info_parts.append(f"End: {end_time}")
-           if info_parts:
-               description = (description + " | " + " | ".join(info_parts)).strip(" |")
-           return description
-
-       def _extra_fields_fn(
-           _chunk: dict[str, Any], _doc_info: dict[str, Any], metadata: dict[str, Any]
-       ) -> dict[str, Any]:
-           return {
-               "event_id": metadata.get("event_id", ""),
-               "calendar_id": metadata.get("calendar_id", ""),
-               "start_time": metadata.get("start_time", ""),
-               "end_time": metadata.get("end_time", ""),
-               "location": metadata.get("location", ""),
-           }
-
-       sources_list = self._build_chunk_sources_from_documents(
-           composio_calendar_docs,
-           title_fn=_title_fn,
-           url_fn=_url_fn,
-           description_fn=_description_fn,
-           extra_fields_fn=_extra_fields_fn,
|
||||
)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"id": 56,
|
||||
"name": "Google Calendar (Composio)",
|
||||
"type": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
"sources": sources_list,
|
||||
}
|
||||
|
||||
return result_object, composio_calendar_docs
|
||||
|
||||
# =========================================================================
|
||||
# Utility Methods for Connector Discovery
|
||||
# =========================================================================
|
||||
|
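A minimal calling sketch for the Composio search methods above (illustrative only; `ConnectorService` as the host class and an open `AsyncSession` named `session` are assumptions, not shown in this excerpt):

async def demo_gmail_search(session) -> None:
    service = ConnectorService(session)  # assumed host class of the methods above
    sources, docs = await service.search_composio_gmail(
        user_query="quarterly report",
        search_space_id=1,
        top_k=10,
    )
    # `sources` feeds the citation UI; `docs` are the RRF-fused langchain documents.
    print(sources["name"], len(docs))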
13 surfsense_backend/app/services/gmail/__init__.py Normal file
@@ -0,0 +1,13 @@
from app.services.gmail.kb_sync_service import GmailKBSyncService
from app.services.gmail.tool_metadata_service import (
    GmailAccount,
    GmailMessage,
    GmailToolMetadataService,
)

__all__ = [
    "GmailAccount",
    "GmailKBSyncService",
    "GmailMessage",
    "GmailToolMetadataService",
]
169 surfsense_backend/app/services/gmail/kb_sync_service.py Normal file
@@ -0,0 +1,169 @@
import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


class GmailKBSyncService:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        message_id: str,
        thread_id: str,
        subject: str,
        sender: str,
        date_str: str,
        body_text: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
        draft_id: str | None = None,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Gmail message %s already exists (doc_id=%s), skipping",
                    message_id,
                    existing.id,
                )
                return {"status": "success"}

            indexable_content = (
                f"Gmail Message: {subject}\n\nFrom: {sender}\nDate: {date_str}\n\n"
                f"{body_text or ''}"
            ).strip()
            if not indexable_content:
                indexable_content = f"Gmail message: {subject}"

            content_hash = generate_content_hash(indexable_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    logger.info(
                        "Content-hash collision for Gmail message %s -- identical content "
                        "exists in doc %s. Using unique_identifier_hash as content_hash.",
                        message_id,
                        dup.id,
                    )
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "subject": subject,
                "sender": sender,
                "document_type": "Gmail Message",
                "connector_type": "Gmail",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    indexable_content, user_llm, doc_metadata_for_summary
                )
            else:
                logger.warning("No LLM configured -- using fallback summary")
                summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(indexable_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            doc_metadata = {
                "message_id": message_id,
                "thread_id": thread_id,
                "subject": subject,
                "sender": sender,
                "date": date_str,
                "connector_id": connector_id,
                "indexed_at": now_str,
            }
            if draft_id:
                doc_metadata["draft_id"] = draft_id

            document = Document(
                title=subject,
                document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
                document_metadata=doc_metadata,
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                source_markdown=body_text,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, subject=%s, chunks=%d",
                document.id,
                subject,
                len(chunks),
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                logger.warning(
                    "Duplicate constraint hit during KB sync for message %s. "
                    "Rolling back -- periodic indexer will handle it. Error: %s",
                    message_id,
                    e,
                )
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for message %s: %s",
                message_id,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}
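A usage sketch for the service above, showing how a HITL draft action might sync the new message into the knowledge base immediately (caller-side names such as `draft` and `ctx` are assumptions, not part of this diff):

async def on_draft_created(session, draft: dict, ctx) -> None:
    result = await GmailKBSyncService(session).sync_after_create(
        message_id=draft["message_id"],
        thread_id=draft["thread_id"],
        subject=draft["subject"],
        sender=ctx.user_email,
        date_str=draft["date"],
        body_text=draft.get("body"),
        connector_id=ctx.connector_id,
        search_space_id=ctx.search_space_id,
        user_id=ctx.user_id,
        draft_id=draft.get("draft_id"),
    )
    if result["status"] != "success":
        # Duplicates and transient failures are non-fatal here: the periodic
        # indexer will pick the message up on its next run.
        pass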
451 surfsense_backend/app/services/gmail/tool_metadata_service.py Normal file
@@ -0,0 +1,451 @@
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy import String, and_, cast, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials

logger = logging.getLogger(__name__)


@dataclass
class GmailAccount:
    id: int
    name: str
    email: str

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "GmailAccount":
        email = ""
        if connector.name and " - " in connector.name:
            email = connector.name.split(" - ", 1)[1]
        return cls(id=connector.id, name=connector.name, email=email)

    def to_dict(self) -> dict:
        return {"id": self.id, "name": self.name, "email": self.email}


@dataclass
class GmailMessage:
    message_id: str
    thread_id: str
    subject: str
    sender: str
    date: str
    connector_id: int
    document_id: int

    @classmethod
    def from_document(cls, document: Document) -> "GmailMessage":
        meta = document.document_metadata or {}
        return cls(
            message_id=meta.get("message_id", ""),
            thread_id=meta.get("thread_id", ""),
            subject=meta.get("subject", document.title),
            sender=meta.get("sender", ""),
            date=meta.get("date", ""),
            connector_id=document.connector_id,
            document_id=document.id,
        )

    def to_dict(self) -> dict:
        return {
            "message_id": self.message_id,
            "thread_id": self.thread_id,
            "subject": self.subject,
            "sender": self.sender,
            "date": self.date,
            "connector_id": self.connector_id,
            "document_id": self.document_id,
        }


class GmailToolMetadataService:
    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        if (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)

        config_data = dict(connector.config)

        from app.config import config
        from app.utils.oauth_security import TokenEncryption

        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and config.SECRET_KEY:
            token_encryption = TokenEncryption(config.SECRET_KEY)
            if config_data.get("token"):
                config_data["token"] = token_encryption.decrypt_token(
                    config_data["token"]
                )
            if config_data.get("refresh_token"):
                config_data["refresh_token"] = token_encryption.decrypt_token(
                    config_data["refresh_token"]
                )
            if config_data.get("client_secret"):
                config_data["client_secret"] = token_encryption.decrypt_token(
                    config_data["client_secret"]
                )

        exp = config_data.get("expiry", "")
        if exp:
            exp = exp.replace("Z", "")

        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=datetime.fromisoformat(exp) if exp else None,
        )

    async def _check_account_health(self, connector_id: int) -> bool:
        """Check if a Gmail connector's credentials are still valid.

        Uses a lightweight ``users().getProfile(userId='me')`` call.

        Returns True if the credentials are expired/invalid, False if healthy.
        """
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            connector = result.scalar_one_or_none()
            if not connector:
                return True

            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)
            await asyncio.get_event_loop().run_in_executor(
                None, lambda: service.users().getProfile(userId="me").execute()
            )
            return False
        except Exception as e:
            logger.warning(
                "Gmail connector %s health check failed: %s",
                connector_id,
                e,
            )
            return True

    async def _persist_auth_expired(self, connector_id: int) -> None:
        """Persist ``auth_expired: True`` to the connector config if not already set."""
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            db_connector = result.scalar_one_or_none()
            if db_connector and not db_connector.config.get("auth_expired"):
                db_connector.config = {**db_connector.config, "auth_expired": True}
                flag_modified(db_connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(db_connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            )

    async def _get_accounts(
        self, search_space_id: int, user_id: str
    ) -> list[GmailAccount]:
        result = await self._db_session.execute(
            select(SearchSourceConnector)
            .filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(
                        [
                            SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                            SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                        ]
                    ),
                )
            )
            .order_by(SearchSourceConnector.last_indexed_at.desc())
        )
        connectors = result.scalars().all()
        return [GmailAccount.from_connector(c) for c in connectors]

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        accounts = await self._get_accounts(search_space_id, user_id)

        if not accounts:
            return {
                "accounts": [],
                "error": "No Gmail account connected",
            }

        accounts_with_status = []
        for acc in accounts:
            acc_dict = acc.to_dict()
            auth_expired = await self._check_account_health(acc.id)
            acc_dict["auth_expired"] = auth_expired
            if auth_expired:
                await self._persist_auth_expired(acc.id)
            else:
                try:
                    result = await self._db_session.execute(
                        select(SearchSourceConnector).where(
                            SearchSourceConnector.id == acc.id
                        )
                    )
                    connector = result.scalar_one_or_none()
                    if connector:
                        creds = await self._build_credentials(connector)
                        service = build("gmail", "v1", credentials=creds)
                        profile = await asyncio.get_event_loop().run_in_executor(
                            None,
                            lambda service=service: (
                                service.users().getProfile(userId="me").execute()
                            ),
                        )
                        acc_dict["email"] = profile.get("emailAddress", "")
                except Exception:
                    logger.warning(
                        "Failed to fetch email for Gmail connector %s",
                        acc.id,
                        exc_info=True,
                    )
            accounts_with_status.append(acc_dict)

        return {"accounts": accounts_with_status}

    async def get_update_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Draft '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the draft doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different. "
                    "Please check the exact draft subject in Gmail."
                )
            }

        account = GmailAccount.from_connector(connector)
        message = GmailMessage.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)

        result: dict = {
            "account": acc_dict,
            "email": message.to_dict(),
        }

        meta = document.document_metadata or {}
        if meta.get("draft_id"):
            result["draft_id"] = meta["draft_id"]

        if not auth_expired:
            existing_body = await self._fetch_draft_body(
                connector, message.message_id, meta.get("draft_id")
            )
            if existing_body is not None:
                result["existing_body"] = existing_body

        return result

    async def _fetch_draft_body(
        self,
        connector: SearchSourceConnector,
        message_id: str,
        draft_id: str | None,
    ) -> str | None:
        """Fetch the plain-text body of a Gmail draft via the API.

        Tries ``drafts.get`` first (if *draft_id* is available), then falls
        back to scanning ``drafts.list`` to resolve the draft by *message_id*.
        Returns ``None`` on any failure so callers can degrade gracefully.
        """
        try:
            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)

            if not draft_id:
                draft_id = await self._find_draft_id(service, message_id)
                if not draft_id:
                    return None

            draft = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda: (
                    service.users()
                    .drafts()
                    .get(userId="me", id=draft_id, format="full")
                    .execute()
                ),
            )

            payload = draft.get("message", {}).get("payload", {})
            return self._extract_body_from_payload(payload)
        except Exception:
            logger.warning(
                "Failed to fetch draft body for message_id=%s",
                message_id,
                exc_info=True,
            )
            return None

    async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
        """Resolve a draft ID from its message ID by scanning drafts.list."""
        try:
            page_token = None
            while True:
                kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
                if page_token:
                    kwargs["pageToken"] = page_token
                response = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda kwargs=kwargs: (
                        service.users().drafts().list(**kwargs).execute()
                    ),
                )
                for draft in response.get("drafts", []):
                    if draft.get("message", {}).get("id") == message_id:
                        return draft["id"]
                page_token = response.get("nextPageToken")
                if not page_token:
                    break
            return None
        except Exception:
            logger.warning(
                "Failed to look up draft by message_id=%s", message_id, exc_info=True
            )
            return None

    @staticmethod
    def _extract_body_from_payload(payload: dict) -> str | None:
        """Extract the plain-text (or html→text) body from a Gmail payload."""
        import base64

        def _get_parts(p: dict) -> list[dict]:
            if "parts" in p:
                parts: list[dict] = []
                for sub in p["parts"]:
                    parts.extend(_get_parts(sub))
                return parts
            return [p]

        parts = _get_parts(payload)
        text_content = ""
        for part in parts:
            mime_type = part.get("mimeType", "")
            data = part.get("body", {}).get("data", "")
            if mime_type == "text/plain" and data:
                text_content += base64.urlsafe_b64decode(data + "===").decode(
                    "utf-8", errors="ignore"
                )
            elif mime_type == "text/html" and data and not text_content:
                from markdownify import markdownify as md

                raw_html = base64.urlsafe_b64decode(data + "===").decode(
                    "utf-8", errors="ignore"
                )
                text_content = md(raw_html).strip()

        return text_content.strip() if text_content.strip() else None

    async def get_trash_context(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> dict:
        document, connector = await self._resolve_email(
            search_space_id, user_id, email_ref
        )

        if not document or not connector:
            return {
                "error": (
                    f"Email '{email_ref}' not found in your indexed Gmail messages. "
                    "This could mean: (1) the email doesn't exist, "
                    "(2) it hasn't been indexed yet, "
                    "or (3) the subject is different."
                )
            }

        account = GmailAccount.from_connector(connector)
        message = GmailMessage.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)

        return {
            "account": acc_dict,
            "email": message.to_dict(),
        }

    async def _resolve_email(
        self, search_space_id: int, user_id: str, email_ref: str
    ) -> tuple[Document | None, SearchSourceConnector | None]:
        result = await self._db_session.execute(
            select(Document, SearchSourceConnector)
            .join(
                SearchSourceConnector,
                Document.connector_id == SearchSourceConnector.id,
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type.in_(
                        [
                            DocumentType.GOOGLE_GMAIL_CONNECTOR,
                            DocumentType.COMPOSIO_GMAIL_CONNECTOR,
                        ]
                    ),
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(cast(Document.document_metadata["subject"], String))
                        == func.lower(email_ref),
                        func.lower(Document.title) == func.lower(email_ref),
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        row = result.first()
        if row:
            return row[0], row[1]
        return None, None
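A sketch of how a tool-metadata consumer might use the service above before proposing a draft update (illustrative; the caller shape and the subject string are assumptions):

async def preview_draft_update(session, search_space_id: int, user_id: str) -> str:
    svc = GmailToolMetadataService(session)
    ctx = await svc.get_update_context(search_space_id, user_id, "Re: Q3 plan")
    if "error" in ctx:
        return ctx["error"]  # draft not indexed, or subject ambiguous
    if ctx["account"].get("auth_expired"):
        return "Gmail credentials expired; please re-authenticate."
    # existing_body is only present when the live draft could be fetched.
    return ctx.get("existing_body", "")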
13 surfsense_backend/app/services/google_calendar/__init__.py Normal file
@@ -0,0 +1,13 @@
from app.services.google_calendar.kb_sync_service import GoogleCalendarKBSyncService
from app.services.google_calendar.tool_metadata_service import (
    GoogleCalendarAccount,
    GoogleCalendarEvent,
    GoogleCalendarToolMetadataService,
)

__all__ = [
    "GoogleCalendarAccount",
    "GoogleCalendarEvent",
    "GoogleCalendarKBSyncService",
    "GoogleCalendarToolMetadataService",
]
374 surfsense_backend/app/services/google_calendar/kb_sync_service.py Normal file
@@ -0,0 +1,374 @@
import asyncio
import logging
from datetime import datetime

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)
from app.utils.google_credentials import build_composio_credentials

logger = logging.getLogger(__name__)


class GoogleCalendarKBSyncService:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        event_id: str,
        event_summary: str,
        calendar_id: str,
        start_time: str,
        end_time: str,
        location: str | None,
        html_link: str | None,
        description: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Calendar event %s already exists (doc_id=%s), skipping",
                    event_id,
                    existing.id,
                )
                return {"status": "success"}

            indexable_content = (
                f"Google Calendar Event: {event_summary}\n\n"
                f"Start: {start_time}\n"
                f"End: {end_time}\n"
                f"Location: {location or 'N/A'}\n\n"
                f"{description or ''}"
            ).strip()

            content_hash = generate_content_hash(indexable_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    logger.info(
                        "Content-hash collision for Calendar event %s -- identical content "
                        "exists in doc %s. Using unique_identifier_hash as content_hash.",
                        event_id,
                        dup.id,
                    )
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "event_summary": event_summary,
                "start_time": start_time,
                "end_time": end_time,
                "document_type": "Google Calendar Event",
                "connector_type": "Google Calendar",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    indexable_content, user_llm, doc_metadata_for_summary
                )
            else:
                logger.warning("No LLM configured -- using fallback summary")
                summary_content = (
                    f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(indexable_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document = Document(
                title=event_summary,
                document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
                document_metadata={
                    "event_id": event_id,
                    "event_summary": event_summary,
                    "calendar_id": calendar_id,
                    "start_time": start_time,
                    "end_time": end_time,
                    "location": location,
                    "html_link": html_link,
                    "source_connector": "google_calendar",
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                source_markdown=indexable_content,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, event=%s, chunks=%d",
                document.id,
                event_summary,
                len(chunks),
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                logger.warning(
                    "Duplicate constraint hit during KB sync for event %s. "
                    "Rolling back -- periodic indexer will handle it. Error: %s",
                    event_id,
                    e,
                )
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for event %s: %s",
                event_id,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def sync_after_update(
        self,
        document_id: int,
        event_id: str,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            document = await self.db_session.get(Document, document_id)
            if not document:
                logger.warning("Document %s not found in KB", document_id)
                return {"status": "not_indexed"}

            creds = await self._build_credentials_for_connector(connector_id)
            loop = asyncio.get_event_loop()
            service = await loop.run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            calendar_id = (document.document_metadata or {}).get(
                "calendar_id", "primary"
            )
            live_event = await loop.run_in_executor(
                None,
                lambda: (
                    service.events()
                    .get(calendarId=calendar_id, eventId=event_id)
                    .execute()
                ),
            )

            event_summary = live_event.get("summary", "")
            description = live_event.get("description", "")
            location = live_event.get("location", "")

            start_data = live_event.get("start", {})
            start_time = start_data.get("dateTime", start_data.get("date", ""))

            end_data = live_event.get("end", {})
            end_time = end_data.get("dateTime", end_data.get("date", ""))

            attendees = [
                {
                    "email": a.get("email", ""),
                    "responseStatus": a.get("responseStatus", ""),
                }
                for a in live_event.get("attendees", [])
            ]

            indexable_content = (
                f"Google Calendar Event: {event_summary}\n\n"
                f"Start: {start_time}\n"
                f"End: {end_time}\n"
                f"Location: {location or 'N/A'}\n\n"
                f"{description or ''}"
            ).strip()

            if not indexable_content:
                return {"status": "error", "message": "Event produced empty content"}

            user_llm = await get_user_long_context_llm(
                self.db_session, user_id, search_space_id, disable_streaming=True
            )

            doc_metadata_for_summary = {
                "event_summary": event_summary,
                "start_time": start_time,
                "end_time": end_time,
                "document_type": "Google Calendar Event",
                "connector_type": "Google Calendar",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    indexable_content, user_llm, doc_metadata_for_summary
                )
            else:
                summary_content = (
                    f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(indexable_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document.title = event_summary
            document.content = summary_content
            document.content_hash = generate_content_hash(
                indexable_content, search_space_id
            )
            document.embedding = summary_embedding

            document.document_metadata = {
                **(document.document_metadata or {}),
                "event_id": event_id,
                "event_summary": event_summary,
                "calendar_id": calendar_id,
                "start_time": start_time,
                "end_time": end_time,
                "location": location,
                "description": description,
                "attendees": attendees,
                "html_link": live_event.get("htmlLink", ""),
                "indexed_at": now_str,
                "connector_id": connector_id,
            }
            flag_modified(document, "document_metadata")

            await safe_set_chunks(self.db_session, document, chunks)
            document.updated_at = get_current_timestamp()

            await self.db_session.commit()

            logger.info(
                "KB sync after update succeeded for document %s (event: %s)",
                document_id,
                event_summary,
            )
            return {"status": "success"}

        except Exception as e:
            logger.error(
                "KB sync after update failed for document %s: %s",
                document_id,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
        result = await self.db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        connector = result.scalar_one_or_none()
        if not connector:
            raise ValueError(f"Connector {connector_id} not found")

        if (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)
            raise ValueError("Composio connected_account_id not found")

        config_data = dict(connector.config)

        from app.config import config as app_config
        from app.utils.oauth_security import TokenEncryption

        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and app_config.SECRET_KEY:
            token_encryption = TokenEncryption(app_config.SECRET_KEY)
            if config_data.get("token"):
                config_data["token"] = token_encryption.decrypt_token(
                    config_data["token"]
                )
            if config_data.get("refresh_token"):
                config_data["refresh_token"] = token_encryption.decrypt_token(
                    config_data["refresh_token"]
                )
            if config_data.get("client_secret"):
                config_data["client_secret"] = token_encryption.decrypt_token(
                    config_data["client_secret"]
                )

        exp = config_data.get("expiry", "")
        if exp:
            exp = exp.replace("Z", "")

        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=datetime.fromisoformat(exp) if exp else None,
        )
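A sketch of the update path, assuming the HITL workflow already knows the KB document id from an earlier resolve step (names outside this diff are hypothetical):

async def on_event_rescheduled(session, doc_id: int, event_id: str, ctx) -> None:
    result = await GoogleCalendarKBSyncService(session).sync_after_update(
        document_id=doc_id,
        event_id=event_id,
        connector_id=ctx.connector_id,
        search_space_id=ctx.search_space_id,
        user_id=ctx.user_id,
    )
    if result["status"] == "not_indexed":
        # The event was never indexed; a caller could fall back to
        # sync_after_create with the live event details instead.
        pass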
431 surfsense_backend/app/services/google_calendar/tool_metadata_service.py Normal file
@@ -0,0 +1,431 @@
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy import String, and_, cast, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials

logger = logging.getLogger(__name__)

CALENDAR_CONNECTOR_TYPES = [
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]

CALENDAR_DOCUMENT_TYPES = [
    DocumentType.GOOGLE_CALENDAR_CONNECTOR,
    DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]


@dataclass
class GoogleCalendarAccount:
    id: int
    name: str

    @classmethod
    def from_connector(
        cls, connector: SearchSourceConnector
    ) -> "GoogleCalendarAccount":
        return cls(id=connector.id, name=connector.name)

    def to_dict(self) -> dict:
        return {"id": self.id, "name": self.name}


@dataclass
class GoogleCalendarEvent:
    event_id: str
    summary: str
    start: str
    end: str
    description: str
    location: str
    attendees: list
    calendar_id: str
    document_id: int
    indexed_at: str | None

    @classmethod
    def from_document(cls, document: Document) -> "GoogleCalendarEvent":
        meta = document.document_metadata or {}
        return cls(
            event_id=meta.get("event_id", ""),
            summary=meta.get("event_summary", document.title),
            start=meta.get("start_time", ""),
            end=meta.get("end_time", ""),
            description=meta.get("description", ""),
            location=meta.get("location", ""),
            attendees=meta.get("attendees", []),
            calendar_id=meta.get("calendar_id", "primary"),
            document_id=document.id,
            indexed_at=meta.get("indexed_at"),
        )

    def to_dict(self) -> dict:
        return {
            "event_id": self.event_id,
            "summary": self.summary,
            "start": self.start,
            "end": self.end,
            "description": self.description,
            "location": self.location,
            "attendees": self.attendees,
            "calendar_id": self.calendar_id,
            "document_id": self.document_id,
            "indexed_at": self.indexed_at,
        }


class GoogleCalendarToolMetadataService:
    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        if (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)
            raise ValueError("Composio connected_account_id not found")

        config_data = dict(connector.config)

        from app.config import config as app_config
        from app.utils.oauth_security import TokenEncryption

        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and app_config.SECRET_KEY:
            token_encryption = TokenEncryption(app_config.SECRET_KEY)
            if config_data.get("token"):
                config_data["token"] = token_encryption.decrypt_token(
                    config_data["token"]
                )
            if config_data.get("refresh_token"):
                config_data["refresh_token"] = token_encryption.decrypt_token(
                    config_data["refresh_token"]
                )
            if config_data.get("client_secret"):
                config_data["client_secret"] = token_encryption.decrypt_token(
                    config_data["client_secret"]
                )

        exp = config_data.get("expiry", "")
        if exp:
            exp = exp.replace("Z", "")

        return Credentials(
            token=config_data.get("token"),
            refresh_token=config_data.get("refresh_token"),
            token_uri=config_data.get("token_uri"),
            client_id=config_data.get("client_id"),
            client_secret=config_data.get("client_secret"),
            scopes=config_data.get("scopes", []),
            expiry=datetime.fromisoformat(exp) if exp else None,
        )

    async def _check_account_health(self, connector_id: int) -> bool:
        """Check if a Google Calendar connector's credentials are still valid.

        Uses a lightweight calendarList.list(maxResults=1) call to verify access.

        Returns True if the credentials are expired/invalid, False if healthy.
        """
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            connector = result.scalar_one_or_none()
            if not connector:
                return True

            creds = await self._build_credentials(connector)
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,
                lambda: (
                    build("calendar", "v3", credentials=creds)
                    .calendarList()
                    .list(maxResults=1)
                    .execute()
                ),
            )
            return False
        except Exception as e:
            logger.warning(
                "Google Calendar connector %s health check failed: %s",
                connector_id,
                e,
            )
            return True

    async def _persist_auth_expired(self, connector_id: int) -> None:
        """Persist ``auth_expired: True`` to the connector config if not already set."""
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            db_connector = result.scalar_one_or_none()
            if db_connector and not db_connector.config.get("auth_expired"):
                db_connector.config = {**db_connector.config, "auth_expired": True}
                flag_modified(db_connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(db_connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            )

    async def _get_accounts(
        self, search_space_id: int, user_id: str
    ) -> list[GoogleCalendarAccount]:
        result = await self._db_session.execute(
            select(SearchSourceConnector)
            .filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
                )
            )
            .order_by(SearchSourceConnector.last_indexed_at.desc())
        )
        connectors = result.scalars().all()
        return [GoogleCalendarAccount.from_connector(c) for c in connectors]

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        accounts = await self._get_accounts(search_space_id, user_id)

        if not accounts:
            return {
                "accounts": [],
                "error": "No Google Calendar account connected",
            }

        accounts_with_status = []
        for acc in accounts:
            acc_dict = acc.to_dict()
            auth_expired = await self._check_account_health(acc.id)
            acc_dict["auth_expired"] = auth_expired
            if auth_expired:
                await self._persist_auth_expired(acc.id)
            accounts_with_status.append(acc_dict)

        healthy_account = next(
            (a for a in accounts_with_status if not a.get("auth_expired")), None
        )
        if not healthy_account:
            return {
                "accounts": accounts_with_status,
                "calendars": [],
                "timezone": "",
                "error": "All connected Google Calendar accounts have expired credentials",
            }

        connector_id = healthy_account["id"]
        result = await self._db_session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == connector_id
            )
        )
        connector = result.scalar_one_or_none()

        calendars = []
        timezone_str = ""
        if connector:
            try:
                creds = await self._build_credentials(connector)
                loop = asyncio.get_event_loop()
                service = await loop.run_in_executor(
                    None, lambda: build("calendar", "v3", credentials=creds)
                )

                cal_list = await loop.run_in_executor(
                    None, lambda: service.calendarList().list().execute()
                )
                for cal in cal_list.get("items", []):
                    calendars.append(
                        {
                            "id": cal.get("id", ""),
                            "summary": cal.get("summary", ""),
                            "primary": cal.get("primary", False),
                        }
                    )

                tz_setting = await loop.run_in_executor(
                    None,
                    lambda: service.settings().get(setting="timezone").execute(),
                )
                timezone_str = tz_setting.get("value", "")
            except Exception:
                logger.warning(
                    "Failed to fetch calendars/timezone for connector %s",
                    connector_id,
                    exc_info=True,
                )

        return {
            "accounts": accounts_with_status,
            "calendars": calendars,
            "timezone": timezone_str,
        }

    async def get_update_context(
        self, search_space_id: int, user_id: str, event_ref: str
    ) -> dict:
        resolved = await self._resolve_event(search_space_id, user_id, event_ref)
        if not resolved:
            return {
                "error": (
                    f"Event '{event_ref}' not found in your indexed Google Calendar events. "
                    "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
                    "or (3) the event name is different."
                )
            }

        document, connector = resolved
        account = GoogleCalendarAccount.from_connector(connector)
        event = GoogleCalendarEvent.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)
            return {
                "error": "Google Calendar credentials have expired. Please re-authenticate.",
                "auth_expired": True,
                "connector_id": connector.id,
            }

        event_dict = event.to_dict()
        try:
            creds = await self._build_credentials(connector)
            loop = asyncio.get_event_loop()
            service = await loop.run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )
            calendar_id = event.calendar_id or "primary"
            live_event = await loop.run_in_executor(
                None,
                lambda: (
                    service.events()
                    .get(calendarId=calendar_id, eventId=event.event_id)
                    .execute()
                ),
            )

            event_dict["summary"] = live_event.get("summary", event_dict["summary"])
            event_dict["description"] = live_event.get(
                "description", event_dict["description"]
            )
            event_dict["location"] = live_event.get("location", event_dict["location"])

            start_data = live_event.get("start", {})
            event_dict["start"] = start_data.get(
                "dateTime", start_data.get("date", event_dict["start"])
            )

            end_data = live_event.get("end", {})
            event_dict["end"] = end_data.get(
                "dateTime", end_data.get("date", event_dict["end"])
            )

            event_dict["attendees"] = [
                {
                    "email": a.get("email", ""),
                    "responseStatus": a.get("responseStatus", ""),
                }
                for a in live_event.get("attendees", [])
            ]
        except Exception:
            logger.warning(
                "Failed to fetch live event data for event %s, using KB metadata",
                event.event_id,
                exc_info=True,
            )

        return {
            "account": acc_dict,
            "event": event_dict,
        }

    async def get_deletion_context(
        self, search_space_id: int, user_id: str, event_ref: str
    ) -> dict:
        resolved = await self._resolve_event(search_space_id, user_id, event_ref)
        if not resolved:
            return {
                "error": (
                    f"Event '{event_ref}' not found in your indexed Google Calendar events. "
                    "This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
                    "or (3) the event name is different."
                )
            }

        document, connector = resolved
        account = GoogleCalendarAccount.from_connector(connector)
        event = GoogleCalendarEvent.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)

        return {
            "account": acc_dict,
            "event": event.to_dict(),
        }

    async def _resolve_event(
        self, search_space_id: int, user_id: str, event_ref: str
    ) -> tuple[Document, SearchSourceConnector] | None:
        result = await self._db_session.execute(
            select(Document, SearchSourceConnector)
            .join(
                SearchSourceConnector,
                Document.connector_id == SearchSourceConnector.id,
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type.in_(CALENDAR_DOCUMENT_TYPES),
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(
                            cast(Document.document_metadata["event_summary"], String)
                        )
                        == func.lower(event_ref),
                        func.lower(Document.title) == func.lower(event_ref),
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        row = result.first()
        if row:
            return row[0], row[1]
        return None
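Illustratively, a creation flow would first pull this context to pick a calendar and timezone (the surrounding handler and its return shape are assumptions):

async def propose_event_creation(session, search_space_id: int, user_id: str) -> dict:
    svc = GoogleCalendarToolMetadataService(session)
    ctx = await svc.get_creation_context(search_space_id, user_id)
    if ctx.get("error"):
        return {"blocked": ctx["error"]}  # no account, or all credentials expired
    primary = next((c for c in ctx["calendars"] if c.get("primary")), None)
    return {
        "calendar_id": primary["id"] if primary else "primary",
        "timezone": ctx.get("timezone") or "UTC",
    }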
surfsense_backend/app/services/google_drive/__init__.py
@@ -1,3 +1,4 @@
+from app.services.google_drive.kb_sync_service import GoogleDriveKBSyncService
from app.services.google_drive.tool_metadata_service import (
    GoogleDriveAccount,
    GoogleDriveFile,
@@ -7,5 +8,6 @@ from app.services.google_drive.tool_metadata_service import (
__all__ = [
    "GoogleDriveAccount",
    "GoogleDriveFile",
+    "GoogleDriveKBSyncService",
    "GoogleDriveToolMetadataService",
]
164 surfsense_backend/app/services/google_drive/kb_sync_service.py Normal file
@@ -0,0 +1,164 @@
import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


class GoogleDriveKBSyncService:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        file_id: str,
        file_name: str,
        mime_type: str,
        web_view_link: str | None,
        content: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Drive file %s already exists (doc_id=%s), skipping",
                    file_id,
                    existing.id,
                )
                return {"status": "success"}

            indexable_content = (content or "").strip()
            if not indexable_content:
                indexable_content = (
                    f"Google Drive file: {file_name} (type: {mime_type})"
                )

            content_hash = generate_content_hash(indexable_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    logger.info(
                        "Content-hash collision for Drive file %s -- identical content "
                        "exists in doc %s. Using unique_identifier_hash as content_hash.",
                        file_id,
                        dup.id,
                    )
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "file_name": file_name,
                "mime_type": mime_type,
                "document_type": "Google Drive File",
                "connector_type": "Google Drive",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    indexable_content, user_llm, doc_metadata_for_summary
                )
            else:
                logger.warning("No LLM configured -- using fallback summary")
                summary_content = (
                    f"Google Drive File: {file_name}\n\n{indexable_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(indexable_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document = Document(
                title=file_name,
                document_type=DocumentType.GOOGLE_DRIVE_FILE,
                document_metadata={
                    "google_drive_file_id": file_id,
                    "google_drive_file_name": file_name,
                    "google_drive_mime_type": mime_type,
                    "web_view_link": web_view_link,
                    "source_connector": "google_drive",
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                source_markdown=content,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, file=%s, chunks=%d",
                document.id,
                file_name,
                len(chunks),
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                logger.warning(
                    "Duplicate constraint hit during KB sync for file %s. "
                    "Rolling back -- periodic indexer will handle it. Error: %s",
                    file_id,
                    e,
                )
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for file %s: %s",
                file_id,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}
|
|
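To make the review concrete, here is a minimal caller sketch for this method. The module path, class name, wrapper function, and IDs are assumptions for illustration (by analogy with the Jira and Linear services later in this diff), not part of the change itself:

# Hypothetical caller (illustration only): module path, class name, and
# IDs below are assumed, not taken from this diff.
from app.services.google_drive.kb_sync_service import GoogleDriveKBSyncService

async def after_drive_create(db_session, created: dict) -> dict:
    service = GoogleDriveKBSyncService(db_session)
    return await service.sync_after_create(
        file_id=created["id"],
        file_name=created["name"],
        mime_type=created["mimeType"],
        web_view_link=created.get("webViewLink"),
        content=None,  # None/empty content falls back to a name/type stub
        connector_id=1,  # example values
        search_space_id=1,
        user_id="user-uuid",
    )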
@ -1,15 +1,21 @@
import logging
from dataclasses import dataclass

from sqlalchemy import and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.google_drive.client import GoogleDriveClient
from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials

logger = logging.getLogger(__name__)


@dataclass
@ -68,12 +74,25 @@ class GoogleDriveToolMetadataService:
            return {
                "accounts": [],
                "supported_types": [],
                "parent_folders": {},
                "error": "No Google Drive account connected",
            }

        accounts_with_status = []
        for acc in accounts:
            acc_dict = acc.to_dict()
            auth_expired = await self._check_account_health(acc.id)
            acc_dict["auth_expired"] = auth_expired
            if auth_expired:
                await self._persist_auth_expired(acc.id)
            accounts_with_status.append(acc_dict)

        parent_folders = await self._get_parent_folders_by_account(accounts_with_status)

        return {
            "accounts": [acc.to_dict() for acc in accounts],
            "accounts": accounts_with_status,
            "supported_types": ["google_doc", "google_sheet"],
            "parent_folders": parent_folders,
        }

    async def get_trash_context(
@ -92,6 +111,8 @@ class GoogleDriveToolMetadataService:
                    SearchSourceConnector.user_id == user_id,
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        document = result.scalars().first()
@ -112,8 +133,12 @@ class GoogleDriveToolMetadataService:
                and_(
                    SearchSourceConnector.id == document.connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnector.connector_type.in_(
                        [
                            SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                            SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                        ]
                    ),
                )
            )
        )
@ -125,8 +150,14 @@ class GoogleDriveToolMetadataService:
        account = GoogleDriveAccount.from_connector(connector)
        file = GoogleDriveFile.from_document(document)

        acc_dict = account.to_dict()
        auth_expired = await self._check_account_health(connector.id)
        acc_dict["auth_expired"] = auth_expired
        if auth_expired:
            await self._persist_auth_expired(connector.id)

        return {
            "account": account.to_dict(),
            "account": acc_dict,
            "file": file.to_dict(),
        }
@ -139,11 +170,150 @@ class GoogleDriveToolMetadataService:
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnector.connector_type.in_(
                        [
                            SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                            SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                        ]
                    ),
                )
            )
            .order_by(SearchSourceConnector.last_indexed_at.desc())
        )
        connectors = result.scalars().all()
        return [GoogleDriveAccount.from_connector(c) for c in connectors]

    async def _check_account_health(self, connector_id: int) -> bool:
        """Check if a Google Drive connector's credentials are still valid.

        Uses a lightweight ``files.list(pageSize=1)`` call to verify access.

        Returns True if the credentials are expired/invalid, False if healthy.
        """
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            connector = result.scalar_one_or_none()
            if not connector:
                return True

            pre_built_creds = None
            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
            ):
                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    pre_built_creds = build_composio_credentials(cca_id)

            client = GoogleDriveClient(
                session=self._db_session,
                connector_id=connector_id,
                credentials=pre_built_creds,
            )
            await client.list_files(
                query="trashed = false", page_size=1, fields="files(id)"
            )
            return False
        except Exception as e:
            logger.warning(
                "Google Drive connector %s health check failed: %s",
                connector_id,
                e,
            )
            return True

    async def _persist_auth_expired(self, connector_id: int) -> None:
        """Persist ``auth_expired: True`` to the connector config if not already set."""
        try:
            result = await self._db_session.execute(
                select(SearchSourceConnector).where(
                    SearchSourceConnector.id == connector_id
                )
            )
            db_connector = result.scalar_one_or_none()
            if db_connector and not db_connector.config.get("auth_expired"):
                db_connector.config = {**db_connector.config, "auth_expired": True}
                flag_modified(db_connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(db_connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired for connector %s",
                connector_id,
                exc_info=True,
            )

    async def _get_parent_folders_by_account(
        self, accounts_with_status: list[dict]
    ) -> dict[int, list[dict]]:
        """Fetch root-level folders for each healthy account.

        Skips accounts where ``auth_expired`` is True so we don't waste an API
        call that will fail anyway.
        """
        parent_folders: dict[int, list[dict]] = {}

        for acc in accounts_with_status:
            connector_id = acc["id"]
            if acc.get("auth_expired"):
                parent_folders[connector_id] = []
                continue

            try:
                result = await self._db_session.execute(
                    select(SearchSourceConnector).where(
                        SearchSourceConnector.id == connector_id
                    )
                )
                connector = result.scalar_one_or_none()
                if not connector:
                    parent_folders[connector_id] = []
                    continue

                pre_built_creds = None
                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
                ):
                    cca_id = connector.config.get("composio_connected_account_id")
                    if cca_id:
                        pre_built_creds = build_composio_credentials(cca_id)

                client = GoogleDriveClient(
                    session=self._db_session,
                    connector_id=connector_id,
                    credentials=pre_built_creds,
                )

                folders, _, error = await client.list_files(
                    query="mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
                    fields="files(id, name)",
                    page_size=50,
                )

                if error:
                    logger.warning(
                        "Failed to list folders for connector %s: %s",
                        connector_id,
                        error,
                    )
                    parent_folders[connector_id] = []
                else:
                    parent_folders[connector_id] = [
                        {"folder_id": f["id"], "name": f["name"]}
                        for f in folders
                        if f.get("id") and f.get("name")
                    ]
            except Exception:
                logger.warning(
                    "Error fetching folders for connector %s",
                    connector_id,
                    exc_info=True,
                )
                parent_folders[connector_id] = []

        return parent_folders
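A short sketch of how a caller might consume this context. The service instance and the method name are assumed here, since the first hunk above starts mid-method; only the returned shape is taken from the diff:

# Illustrative consumer (service handle and method name are assumptions).
context = await drive_metadata_service.get_creation_context(search_space_id, user_id)
for account in context["accounts"]:
    if account["auth_expired"]:
        # No folders were fetched for this account; prompt for re-auth instead.
        print(f"Connector {account['id']} needs re-authentication")
        continue
    for folder in context["parent_folders"].get(account["id"], []):
        print(folder["folder_id"], folder["name"])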
13
surfsense_backend/app/services/jira/__init__.py
Normal file

@ -0,0 +1,13 @@
from app.services.jira.kb_sync_service import JiraKBSyncService
from app.services.jira.tool_metadata_service import (
    JiraIssue,
    JiraToolMetadataService,
    JiraWorkspace,
)

__all__ = [
    "JiraIssue",
    "JiraKBSyncService",
    "JiraToolMetadataService",
    "JiraWorkspace",
]
254
surfsense_backend/app/services/jira/kb_sync_service.py
Normal file

@ -0,0 +1,254 @@
import asyncio
import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


class JiraKBSyncService:
    """Syncs Jira issue documents to the knowledge base after HITL actions."""

    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        issue_id: str,
        issue_identifier: str,
        issue_title: str,
        description: str | None,
        state: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Jira issue %s already exists (doc_id=%s), skipping",
                    issue_identifier,
                    existing.id,
                )
                return {"status": "success"}

            indexable_content = (description or "").strip()
            if not indexable_content:
                indexable_content = f"Jira Issue {issue_identifier}: {issue_title}"

            issue_content = (
                f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
            )

            content_hash = generate_content_hash(issue_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "issue_id": issue_identifier,
                "issue_title": issue_title,
                "document_type": "Jira Issue",
                "connector_type": "Jira",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    issue_content, user_llm, doc_metadata_for_summary
                )
            else:
                summary_content = (
                    f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(issue_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document = Document(
                title=f"{issue_identifier}: {issue_title}",
                document_type=DocumentType.JIRA_CONNECTOR,
                document_metadata={
                    "issue_id": issue_id,
                    "issue_identifier": issue_identifier,
                    "issue_title": issue_title,
                    "state": state or "Unknown",
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, issue=%s",
                document.id,
                issue_identifier,
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for issue %s: %s",
                issue_identifier,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def sync_after_update(
        self,
        document_id: int,
        issue_id: str,
        user_id: str,
        search_space_id: int,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            document = await self.db_session.get(Document, document_id)
            if not document:
                return {"status": "not_indexed"}

            connector_id = document.connector_id
            if not connector_id:
                return {"status": "error", "message": "Document has no connector_id"}

            jira_history = JiraHistoryConnector(
                session=self.db_session, connector_id=connector_id
            )
            jira_client = await jira_history._get_jira_client()
            issue_raw = await asyncio.to_thread(jira_client.get_issue, issue_id)
            formatted = jira_client.format_issue(issue_raw)
            issue_content = jira_client.format_issue_to_markdown(formatted)

            if not issue_content:
                return {"status": "error", "message": "Issue produced empty content"}

            issue_identifier = formatted.get("key", "")
            issue_title = formatted.get("title", "")
            state = formatted.get("status", "Unknown")
            comment_count = len(formatted.get("comments", []))

            user_llm = await get_user_long_context_llm(
                self.db_session, user_id, search_space_id, disable_streaming=True
            )

            if user_llm:
                doc_meta = {
                    "issue_key": issue_identifier,
                    "issue_title": issue_title,
                    "status": state,
                    "document_type": "Jira Issue",
                    "connector_type": "Jira",
                }
                summary_content, summary_embedding = await generate_document_summary(
                    issue_content, user_llm, doc_meta
                )
            else:
                summary_content = (
                    f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(issue_content)

            document.title = f"{issue_identifier}: {issue_title}"
            document.content = summary_content
            document.content_hash = generate_content_hash(
                issue_content, search_space_id
            )
            document.embedding = summary_embedding

            from sqlalchemy.orm.attributes import flag_modified

            document.document_metadata = {
                **(document.document_metadata or {}),
                "issue_id": issue_id,
                "issue_identifier": issue_identifier,
                "issue_title": issue_title,
                "state": state,
                "comment_count": comment_count,
                "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "connector_id": connector_id,
            }
            flag_modified(document, "document_metadata")
            await safe_set_chunks(self.db_session, document, chunks)
            document.updated_at = get_current_timestamp()

            await self.db_session.commit()

            logger.info(
                "KB sync successful for document %s (%s: %s)",
                document_id,
                issue_identifier,
                issue_title,
            )
            return {"status": "success"}

        except Exception as e:
            logger.error(
                "KB sync failed for document %s: %s", document_id, e, exc_info=True
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}
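For context, a minimal sketch of how the update path might be wired from a HITL tool handler. Only the import is taken from the new package's __init__; the wrapper function, session, IDs, and fallback branch are illustrative:

# Sketch only: wrapper name, IDs, and the fallback branch are assumptions.
from app.services.jira import JiraKBSyncService

async def on_jira_update_approved(db_session, document_id: int, issue_key: str):
    result = await JiraKBSyncService(db_session).sync_after_update(
        document_id=document_id,
        issue_id=issue_key,  # e.g. "PROJ-123"
        user_id="user-uuid",
        search_space_id=1,
    )
    if result["status"] != "success":
        # "not_indexed" or "error": leave re-indexing to the periodic indexer.
        pass
    return result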
332
surfsense_backend/app/services/jira/tool_metadata_service.py
Normal file

@ -0,0 +1,332 @@
import asyncio
import logging
from dataclasses import dataclass

from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.jira_history import JiraHistoryConnector
from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)

logger = logging.getLogger(__name__)


@dataclass
class JiraWorkspace:
    """Represents a Jira connector as a workspace for tool context."""

    id: int
    name: str
    base_url: str

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "JiraWorkspace":
        return cls(
            id=connector.id,
            name=connector.name,
            base_url=connector.config.get("base_url", ""),
        )

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "name": self.name,
            "base_url": self.base_url,
        }


@dataclass
class JiraIssue:
    """Represents an indexed Jira issue resolved from the knowledge base."""

    issue_id: str
    issue_identifier: str
    issue_title: str
    state: str
    connector_id: int
    document_id: int
    indexed_at: str | None

    @classmethod
    def from_document(cls, document: Document) -> "JiraIssue":
        meta = document.document_metadata or {}
        return cls(
            issue_id=meta.get("issue_id", ""),
            issue_identifier=meta.get("issue_identifier", ""),
            issue_title=meta.get("issue_title", document.title),
            state=meta.get("state", ""),
            connector_id=document.connector_id,
            document_id=document.id,
            indexed_at=meta.get("indexed_at"),
        )

    def to_dict(self) -> dict:
        return {
            "issue_id": self.issue_id,
            "issue_identifier": self.issue_identifier,
            "issue_title": self.issue_title,
            "state": self.state,
            "connector_id": self.connector_id,
            "document_id": self.document_id,
            "indexed_at": self.indexed_at,
        }


class JiraToolMetadataService:
    """Builds interrupt context for Jira HITL tools."""

    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
        """Check if the Jira connector auth is still valid.

        Returns True if auth is expired/invalid, False if healthy.
        """
        try:
            jira_history = JiraHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            jira_client = await jira_history._get_jira_client()
            await asyncio.to_thread(jira_client.get_myself)
            return False
        except Exception as e:
            logger.warning("Jira connector %s health check failed: %s", connector.id, e)
            try:
                connector.config = {**connector.config, "auth_expired": True}
                flag_modified(connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(connector)
            except Exception:
                logger.warning(
                    "Failed to persist auth_expired for connector %s",
                    connector.id,
                    exc_info=True,
                )
            return True

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        """Return context needed to create a new Jira issue.

        Fetches all connected Jira accounts, and for the first healthy one
        fetches projects, issue types, and priorities.
        """
        connectors = await self._get_all_jira_connectors(search_space_id, user_id)
        if not connectors:
            return {"error": "No Jira account connected"}

        accounts = []
        projects = []
        issue_types = []
        priorities = []
        fetched_context = False

        for connector in connectors:
            auth_expired = await self._check_account_health(connector)
            workspace = JiraWorkspace.from_connector(connector)
            account_info = {
                **workspace.to_dict(),
                "auth_expired": auth_expired,
            }
            accounts.append(account_info)

            if not auth_expired and not fetched_context:
                try:
                    jira_history = JiraHistoryConnector(
                        session=self._db_session, connector_id=connector.id
                    )
                    jira_client = await jira_history._get_jira_client()
                    raw_projects = await asyncio.to_thread(jira_client.get_projects)
                    projects = [
                        {"id": p.get("id"), "key": p.get("key"), "name": p.get("name")}
                        for p in raw_projects
                    ]
                    raw_types = await asyncio.to_thread(jira_client.get_issue_types)
                    seen_type_names: set[str] = set()
                    issue_types = []
                    for t in raw_types:
                        if t.get("subtask", False):
                            continue
                        name = t.get("name")
                        if name not in seen_type_names:
                            seen_type_names.add(name)
                            issue_types.append({"id": t.get("id"), "name": name})
                    raw_priorities = await asyncio.to_thread(jira_client.get_priorities)
                    priorities = [
                        {"id": p.get("id"), "name": p.get("name")}
                        for p in raw_priorities
                    ]
                    fetched_context = True
                except Exception as e:
                    logger.warning(
                        "Failed to fetch Jira context for connector %s: %s",
                        connector.id,
                        e,
                    )

        return {
            "accounts": accounts,
            "projects": projects,
            "issue_types": issue_types,
            "priorities": priorities,
        }

    async def get_update_context(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> dict:
        """Return context needed to update an indexed Jira issue.

        Resolves the issue from the KB, then fetches current details from the Jira API.
        """
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
                "Please make sure the issue is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = await self._check_account_health(connector)
        if auth_expired:
            return {
                "error": "Jira authentication has expired. Please re-authenticate.",
                "auth_expired": True,
                "connector_id": connector.id,
            }

        workspace = JiraWorkspace.from_connector(connector)
        issue = JiraIssue.from_document(document)

        try:
            jira_history = JiraHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            jira_client = await jira_history._get_jira_client()
            issue_data = await asyncio.to_thread(jira_client.get_issue, issue.issue_id)
            formatted = jira_client.format_issue(issue_data)
        except Exception as e:
            error_str = str(e).lower()
            if (
                "401" in error_str
                or "403" in error_str
                or "authentication" in error_str
            ):
                return {
                    "error": f"Failed to fetch Jira issue: {e!s}",
                    "auth_expired": True,
                    "connector_id": connector.id,
                }
            return {"error": f"Failed to fetch Jira issue: {e!s}"}

        return {
            "account": {**workspace.to_dict(), "auth_expired": False},
            "issue": {
                "issue_id": formatted.get("key", issue.issue_id),
                "issue_identifier": formatted.get("key", issue.issue_identifier),
                "issue_title": formatted.get("title", issue.issue_title),
                "state": formatted.get("status", "Unknown"),
                "priority": formatted.get("priority", "Unknown"),
                "issue_type": formatted.get("issue_type", "Unknown"),
                "assignee": formatted.get("assignee"),
                "description": formatted.get("description"),
                "project": formatted.get("project", ""),
                "document_id": issue.document_id,
                "indexed_at": issue.indexed_at,
            },
        }

    async def get_deletion_context(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> dict:
        """Return context needed to delete a Jira issue (KB metadata only, no API call)."""
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
                "Please make sure the issue is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = connector.config.get("auth_expired", False)
        workspace = JiraWorkspace.from_connector(connector)
        issue = JiraIssue.from_document(document)

        return {
            "account": {**workspace.to_dict(), "auth_expired": auth_expired},
            "issue": issue.to_dict(),
        }

    async def _resolve_issue(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> Document | None:
        """Resolve an issue from KB: issue_identifier -> issue_title -> document.title."""
        ref_lower = issue_ref.lower()

        result = await self._db_session.execute(
            select(Document)
            .join(
                SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type == DocumentType.JIRA_CONNECTOR,
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(
                            Document.document_metadata.op("->>")("issue_identifier")
                        )
                        == ref_lower,
                        func.lower(Document.document_metadata.op("->>")("issue_title"))
                        == ref_lower,
                        func.lower(Document.title) == ref_lower,
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        return result.scalars().first()

    async def _get_all_jira_connectors(
        self, search_space_id: int, user_id: str
    ) -> list[SearchSourceConnector]:
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.JIRA_CONNECTOR,
                )
            )
        )
        return result.scalars().all()

    async def _get_connector_for_document(
        self, document: Document, user_id: str
    ) -> SearchSourceConnector | None:
        if not document.connector_id:
            return None
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.id == document.connector_id,
                    SearchSourceConnector.user_id == user_id,
                )
            )
        )
        return result.scalars().first()
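The return shapes above suggest a three-way branch for tool callers: re-auth, abort, or confirm. A sketch, with the issue key and wrapper function invented for illustration:

# Sketch only: the wrapper and issue key are assumptions; the branch order
# matters because auth-expired responses also carry an "error" key.
from app.services.jira import JiraToolMetadataService

async def build_jira_update_interrupt(db_session, search_space_id: int, user_id: str):
    ctx = await JiraToolMetadataService(db_session).get_update_context(
        search_space_id, user_id, "PROJ-123"  # issue key is an example
    )
    if ctx.get("auth_expired"):
        return {"action": "reauth", "connector_id": ctx["connector_id"]}
    if "error" in ctx:
        return {"action": "abort", "reason": ctx["error"]}
    return {"action": "confirm", "account": ctx["account"], "issue": ctx["issue"]}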
@ -4,29 +4,174 @@ from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.linear_connector import LinearConnector
from app.db import Document
from app.db import Document, DocumentType
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,
    generate_content_hash,
    generate_document_summary,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)


class LinearKBSyncService:
    """Re-indexes a single Linear issue document after a successful update.
    """Syncs Linear issue documents to the knowledge base after HITL actions.

    Mirrors the indexer's Phase-2 logic exactly: fetch fresh issue content,
    run generate_document_summary, create_document_chunks, then update the
    document row in the knowledge base.
    Provides sync_after_create (new issue) and sync_after_update (existing issue).
    Both mirror the indexer's Phase-2 logic: generate summary, create chunks,
    then persist the document row.
    """

    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session

    async def sync_after_create(
        self,
        issue_id: str,
        issue_identifier: str,
        issue_title: str,
        issue_url: str | None,
        description: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
            )

            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Linear issue %s already exists (doc_id=%s), skipping",
                    issue_identifier,
                    existing.id,
                )
                return {"status": "success"}

            indexable_content = (description or "").strip()
            if not indexable_content:
                indexable_content = f"Linear Issue {issue_identifier}: {issue_title}"

            issue_content = (
                f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
            )

            content_hash = generate_content_hash(issue_content, search_space_id)

            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
                if dup:
                    logger.info(
                        "Content-hash collision for Linear issue %s — identical content "
                        "exists in doc %s. Using unique_identifier_hash as content_hash.",
                        issue_identifier,
                        dup.id,
                    )
                    content_hash = unique_hash

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )

            doc_metadata_for_summary = {
                "issue_id": issue_identifier,
                "issue_title": issue_title,
                "document_type": "Linear Issue",
                "connector_type": "Linear",
            }

            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    issue_content, user_llm, doc_metadata_for_summary
                )
            else:
                logger.warning("No LLM configured — using fallback summary")
                summary_content = (
                    f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
                )
                summary_embedding = embed_text(summary_content)

            chunks = await create_document_chunks(issue_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            document = Document(
                title=f"{issue_identifier}: {issue_title}",
                document_type=DocumentType.LINEAR_CONNECTOR,
                document_metadata={
                    "issue_id": issue_id,
                    "issue_identifier": issue_identifier,
                    "issue_title": issue_title,
                    "issue_url": issue_url,
                    "source_connector": "linear",
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )

            self.db_session.add(document)
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()

            logger.info(
                "KB sync after create succeeded: doc_id=%s, issue=%s, chunks=%d",
                document.id,
                issue_identifier,
                len(chunks),
            )
            return {"status": "success"}

        except Exception as e:
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                logger.warning(
                    "Duplicate constraint hit during KB sync for issue %s. "
                    "Rolling back — periodic indexer will handle it. Error: %s",
                    issue_identifier,
                    e,
                )
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}

            logger.error(
                "KB sync after create failed for issue %s: %s",
                issue_identifier,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def sync_after_update(
        self,
        document_id: int,
@ -1,8 +1,10 @@
import logging
from dataclasses import dataclass

from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.connectors.linear_connector import LinearConnector
from app.db import (

@ -12,6 +14,8 @@ from app.db import (
    SearchSourceConnectorType,
)

logger = logging.getLogger(__name__)


@dataclass
class LinearWorkspace:

@ -109,7 +113,34 @@ class LinearToolMetadataService:
                priorities = await self._fetch_priority_values(linear_client)
                teams = await self._fetch_teams_context(linear_client)
            except Exception as e:
                return {"error": f"Failed to fetch Linear context: {e!s}"}
                logger.warning(
                    "Linear connector %s (%s) auth failed, flagging as expired: %s",
                    connector.id,
                    workspace.name,
                    e,
                )
                try:
                    connector.config = {**connector.config, "auth_expired": True}
                    flag_modified(connector, "config")
                    await self._db_session.commit()
                    await self._db_session.refresh(connector)
                except Exception:
                    logger.warning(
                        "Failed to persist auth_expired for connector %s",
                        connector.id,
                        exc_info=True,
                    )
                workspaces.append(
                    {
                        "id": workspace.id,
                        "name": workspace.name,
                        "organization_name": workspace.organization_name,
                        "teams": [],
                        "priorities": [],
                        "auth_expired": True,
                    }
                )
                continue
            workspaces.append(
                {
                    "id": workspace.id,

@ -117,6 +148,7 @@ class LinearToolMetadataService:
                    "organization_name": workspace.organization_name,
                    "teams": teams,
                    "priorities": priorities,
                    "auth_expired": False,
                }
            )

@ -137,8 +169,8 @@ class LinearToolMetadataService:
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your indexed Linear issues. "
                "This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, "
                "error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
                "This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
                "or (3) the title or identifier is different."
            }

@ -157,6 +189,17 @@ class LinearToolMetadataService:
            priorities = await self._fetch_priority_values(linear_client)
            issue_api = await self._fetch_issue_context(linear_client, issue.id)
        except Exception as e:
            error_str = str(e).lower()
            if (
                "401" in error_str
                or "authentication" in error_str
                or "re-authenticate" in error_str
            ):
                return {
                    "error": f"Failed to fetch Linear issue context: {e!s}",
                    "auth_expired": True,
                    "connector_id": connector.id,
                }
            return {"error": f"Failed to fetch Linear issue context: {e!s}"}

        if not issue_api:

@ -210,8 +253,8 @@ class LinearToolMetadataService:
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your indexed Linear issues. "
                "This could mean: (1) the issue doesn't exist, (2) it hasn't been indexed yet, "
                "error": f"Issue '{issue_ref}' not found in your synced Linear issues. "
                "This could mean: (1) the issue doesn't exist, (2) it hasn't been synced yet, "
                "or (3) the title or identifier is different."
            }

@ -319,6 +362,7 @@ class LinearToolMetadataService:
                    ),
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        return result.scalars().first()
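Based on the hunks above, an expired Linear workspace now comes back with empty teams and priorities plus the flag, so callers can filter uniformly across the Google Drive, Jira, and Linear connectors. Example payload (values illustrative, shape taken from the hunks):

# Illustrative values only.
workspaces = [
    {
        "id": 7,
        "name": "Acme",
        "organization_name": "Acme Inc",
        "teams": [],        # emptied when auth_expired is True
        "priorities": [],
        "auth_expired": True,
    },
]
usable = [w for w in workspaces if not w["auth_expired"]]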
Some files were not shown because too many files have changed in this diff