mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Merge pull request #713 from MODSetter/dev
Feat: Add Chat Comments with Mentions, MCP Connectors, and Electric SQL Integration
This commit is contained in:
commit
4793e54b78
164 changed files with 15560 additions and 5455 deletions
12
.env.example
12
.env.example
|
|
@ -9,7 +9,6 @@ FRONTEND_PORT=3000
|
|||
NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 (Default: http://localhost:8000)
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE (Default: LOCAL)
|
||||
NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING (Default: DOCLING)
|
||||
|
||||
# Backend Configuration
|
||||
BACKEND_PORT=8000
|
||||
|
||||
|
|
@ -19,6 +18,17 @@ POSTGRES_PASSWORD=postgres
|
|||
POSTGRES_DB=surfsense
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# Electric-SQL Configuration
|
||||
ELECTRIC_PORT=5133
|
||||
# PostgreSQL host for Electric connection
|
||||
# - 'db' for Docker PostgreSQL (service name in docker-compose)
|
||||
# - 'host.docker.internal' for local PostgreSQL (recommended when Electric runs in Docker)
|
||||
# Note: host.docker.internal works on Docker Desktop (Mac/Windows) and can be enabled on Linux
|
||||
POSTGRES_HOST=db
|
||||
ELECTRIC_DB_USER=electric
|
||||
ELECTRIC_DB_PASSWORD=electric_password
|
||||
NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
|
||||
|
||||
# pgAdmin Configuration
|
||||
PGADMIN_PORT=5050
|
||||
PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
|
|
|
|||
87
.github/workflows/docker_build.yaml
vendored
87
.github/workflows/docker_build.yaml
vendored
|
|
@ -3,17 +3,8 @@ name: Build and Push Docker Image
|
|||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
bump_type:
|
||||
description: 'Version bump type (patch, minor, major)'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
branch:
|
||||
description: 'Branch to tag (leave empty for default branch)'
|
||||
description: 'Branch to build from (leave empty for default branch)'
|
||||
required: false
|
||||
default: ''
|
||||
|
||||
|
|
@ -34,55 +25,37 @@ jobs:
|
|||
ref: ${{ github.event.inputs.branch }}
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get latest SemVer tag and calculate next version
|
||||
- name: Read app version and calculate next Docker build version
|
||||
id: tag_version
|
||||
run: |
|
||||
git fetch --tags
|
||||
LATEST_TAG=$(git tag --list 'v[0-9]*.[0-9]*.[0-9]*' --sort='v:refname' | tail -n 1)
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No previous SemVer tag found. Starting with v0.1.0"
|
||||
case "${{ github.event.inputs.bump_type }}" in
|
||||
patch|minor)
|
||||
NEXT_VERSION="v0.1.0"
|
||||
;;
|
||||
major)
|
||||
NEXT_VERSION="v1.0.0"
|
||||
;;
|
||||
*)
|
||||
echo "Invalid bump type: ${{ github.event.inputs.bump_type }}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
else
|
||||
echo "Latest tag found: $LATEST_TAG"
|
||||
VERSION=${LATEST_TAG#v}
|
||||
MAJOR=$(echo $VERSION | cut -d. -f1)
|
||||
MINOR=$(echo $VERSION | cut -d. -f2)
|
||||
PATCH=$(echo $VERSION | cut -d. -f3)
|
||||
|
||||
case "${{ github.event.inputs.bump_type }}" in
|
||||
patch)
|
||||
PATCH=$((PATCH + 1))
|
||||
;;
|
||||
minor)
|
||||
MINOR=$((MINOR + 1))
|
||||
PATCH=0
|
||||
;;
|
||||
major)
|
||||
MAJOR=$((MAJOR + 1))
|
||||
MINOR=0
|
||||
PATCH=0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid bump type: ${{ github.event.inputs.bump_type }}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
NEXT_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
|
||||
# Read version from pyproject.toml
|
||||
APP_VERSION=$(grep -E '^version = ' surfsense_backend/pyproject.toml | sed 's/version = "\(.*\)"/\1/')
|
||||
echo "App version from pyproject.toml: $APP_VERSION"
|
||||
|
||||
if [ -z "$APP_VERSION" ]; then
|
||||
echo "Error: Could not read version from surfsense_backend/pyproject.toml"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Calculated next version: $NEXT_VERSION"
|
||||
|
||||
# Fetch all tags
|
||||
git fetch --tags
|
||||
|
||||
# Find the latest docker build tag for this app version (format: APP_VERSION.BUILD_NUMBER)
|
||||
# Tags follow pattern: 0.0.11.1, 0.0.11.2, etc.
|
||||
LATEST_BUILD_TAG=$(git tag --list "${APP_VERSION}.*" --sort='-v:refname' | head -n 1)
|
||||
|
||||
if [ -z "$LATEST_BUILD_TAG" ]; then
|
||||
echo "No previous Docker build tag found for version ${APP_VERSION}. Starting with ${APP_VERSION}.1"
|
||||
NEXT_VERSION="${APP_VERSION}.1"
|
||||
else
|
||||
echo "Latest Docker build tag found: $LATEST_BUILD_TAG"
|
||||
# Extract the build number (4th component)
|
||||
BUILD_NUMBER=$(echo "$LATEST_BUILD_TAG" | rev | cut -d. -f1 | rev)
|
||||
NEXT_BUILD=$((BUILD_NUMBER + 1))
|
||||
NEXT_VERSION="${APP_VERSION}.${NEXT_BUILD}"
|
||||
fi
|
||||
|
||||
echo "Calculated next Docker version: $NEXT_VERSION"
|
||||
echo "next_version=$NEXT_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create and Push Tag
|
||||
|
|
@ -94,7 +67,7 @@ jobs:
|
|||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
echo "Tagging commit $COMMIT_SHA with $NEXT_TAG"
|
||||
|
||||
git tag -a "$NEXT_TAG" -m "Release $NEXT_TAG"
|
||||
git tag -a "$NEXT_TAG" -m "Docker build $NEXT_TAG"
|
||||
echo "Pushing tag $NEXT_TAG to origin"
|
||||
git push origin "$NEXT_TAG"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
# SurfSense All-in-One Docker Image
|
||||
# This image bundles PostgreSQL+pgvector, Redis, Backend, and Frontend
|
||||
# Usage: docker run -d -p 3000:3000 -p 8000:8000 -v surfsense-data:/data --name surfsense ghcr.io/modsetter/surfsense:latest
|
||||
# This image bundles PostgreSQL+pgvector, Redis, Electric SQL, Backend, and Frontend
|
||||
# Usage: docker run -d -p 3000:3000 -p 8000:8000 -p 5133:5133 -v surfsense-data:/data --name surfsense ghcr.io/modsetter/surfsense:latest
|
||||
#
|
||||
# Included Services (all run locally by default):
|
||||
# - PostgreSQL 14 + pgvector (vector database)
|
||||
# - Redis (task queue)
|
||||
# - Electric SQL (real-time sync)
|
||||
# - Docling (document processing, CPU-only, OCR disabled)
|
||||
# - Kokoro TTS (local text-to-speech for podcasts)
|
||||
# - Faster-Whisper (local speech-to-text for audio files)
|
||||
|
|
@ -14,7 +15,12 @@
|
|||
# will be available in the future for faster AI inference.
|
||||
|
||||
# ====================
|
||||
# Stage 1: Build Frontend
|
||||
# Stage 1: Get Electric SQL Binary
|
||||
# ====================
|
||||
FROM electricsql/electric:latest AS electric-builder
|
||||
|
||||
# ====================
|
||||
# Stage 2: Build Frontend
|
||||
# ====================
|
||||
FROM node:20-alpine AS frontend-builder
|
||||
|
||||
|
|
@ -42,12 +48,14 @@ RUN pnpm fumadocs-mdx
|
|||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__
|
||||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__
|
||||
ENV NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__
|
||||
ENV NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__
|
||||
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__
|
||||
|
||||
# Build
|
||||
RUN pnpm run build
|
||||
|
||||
# ====================
|
||||
# Stage 2: Runtime Image
|
||||
# Stage 3: Runtime Image
|
||||
# ====================
|
||||
FROM ubuntu:22.04 AS runtime
|
||||
|
||||
|
|
@ -167,6 +175,11 @@ COPY --from=frontend-builder /app/public ./public
|
|||
|
||||
COPY surfsense_web/content/docs /app/surfsense_web/content/docs
|
||||
|
||||
# ====================
|
||||
# Copy Electric SQL Release
|
||||
# ====================
|
||||
COPY --from=electric-builder /app /app/electric-release
|
||||
|
||||
# ====================
|
||||
# Setup Backend
|
||||
# ====================
|
||||
|
|
@ -238,11 +251,22 @@ ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
|
|||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
ENV NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
|
||||
# Electric SQL configuration (ELECTRIC_DATABASE_URL is built dynamically by entrypoint from these values)
|
||||
ENV ELECTRIC_DB_USER=electric
|
||||
ENV ELECTRIC_DB_PASSWORD=electric_password
|
||||
# Note: ELECTRIC_DATABASE_URL is NOT set here - entrypoint builds it dynamically from ELECTRIC_DB_USER/PASSWORD
|
||||
ENV ELECTRIC_INSECURE=true
|
||||
ENV ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
ENV ELECTRIC_PORT=5133
|
||||
ENV PORT=5133
|
||||
ENV NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
|
||||
ENV NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
|
||||
# Data volume
|
||||
VOLUME ["/data"]
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 3000 8000
|
||||
# Expose ports (Frontend: 3000, Backend: 8000, Electric: 5133)
|
||||
EXPOSE 3000 8000 5133
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
|
||||
|
|
|
|||
|
|
@ -7,10 +7,15 @@ services:
|
|||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./scripts/docker/postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/docker/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-postgres}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-surfsense}
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
|
||||
pgadmin:
|
||||
image: dpage/pgadmin4
|
||||
|
|
@ -51,11 +56,14 @@ services:
|
|||
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
|
||||
- LANGCHAIN_TRACING_V2=false
|
||||
- LANGSMITH_TRACING=false
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- NEXT_FRONTEND_URL=http://frontend:3000
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
|
||||
# Run these services seperately in production
|
||||
# Run these services separately in production
|
||||
# celery_worker:
|
||||
# build: ./surfsense_backend
|
||||
# # image: ghcr.io/modsetter/surfsense_backend:latest
|
||||
|
|
@ -110,6 +118,23 @@ services:
|
|||
# - redis
|
||||
# - celery_worker
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:latest
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5133}:3000"
|
||||
environment:
|
||||
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${POSTGRES_HOST:-db}:${POSTGRES_PORT:-5432}/${POSTGRES_DB:-surfsense}?sslmode=disable}
|
||||
- ELECTRIC_INSECURE=true
|
||||
- ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
restart: unless-stopped
|
||||
# depends_on:
|
||||
# - db
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
frontend:
|
||||
build:
|
||||
context: ./surfsense_web
|
||||
|
|
@ -122,8 +147,12 @@ services:
|
|||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
env_file:
|
||||
- ./surfsense_web/.env
|
||||
environment:
|
||||
- NEXT_PUBLIC_ELECTRIC_URL=${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
|
||||
- NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
depends_on:
|
||||
- backend
|
||||
- electric
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,31 @@ if [ -z "$STT_SERVICE" ]; then
|
|||
echo "✅ Using default STT_SERVICE: local/base"
|
||||
fi
|
||||
|
||||
# ================================================
|
||||
# Set Electric SQL configuration
|
||||
# ================================================
|
||||
export ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
|
||||
export ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
|
||||
if [ -z "$ELECTRIC_DATABASE_URL" ]; then
|
||||
export ELECTRIC_DATABASE_URL="postgresql://${ELECTRIC_DB_USER}:${ELECTRIC_DB_PASSWORD}@localhost:5432/${POSTGRES_DB:-surfsense}?sslmode=disable"
|
||||
echo "✅ Electric SQL URL configured dynamically"
|
||||
else
|
||||
# Ensure sslmode=disable is in the URL if not already present
|
||||
if [[ "$ELECTRIC_DATABASE_URL" != *"sslmode="* ]]; then
|
||||
# Add sslmode=disable (handle both cases: with or without existing query params)
|
||||
if [[ "$ELECTRIC_DATABASE_URL" == *"?"* ]]; then
|
||||
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}&sslmode=disable"
|
||||
else
|
||||
export ELECTRIC_DATABASE_URL="${ELECTRIC_DATABASE_URL}?sslmode=disable"
|
||||
fi
|
||||
fi
|
||||
echo "✅ Electric SQL URL configured from environment"
|
||||
fi
|
||||
|
||||
# Set Electric SQL port
|
||||
export ELECTRIC_PORT="${ELECTRIC_PORT:-5133}"
|
||||
export PORT="${ELECTRIC_PORT}"
|
||||
|
||||
# ================================================
|
||||
# Initialize PostgreSQL if needed
|
||||
# ================================================
|
||||
|
|
@ -60,6 +85,11 @@ if [ ! -f /data/postgres/PG_VERSION ]; then
|
|||
echo "local all all trust" >> /data/postgres/pg_hba.conf
|
||||
echo "listen_addresses='*'" >> /data/postgres/postgresql.conf
|
||||
|
||||
# Enable logical replication for Electric SQL
|
||||
echo "wal_level = logical" >> /data/postgres/postgresql.conf
|
||||
echo "max_replication_slots = 10" >> /data/postgres/postgresql.conf
|
||||
echo "max_wal_senders = 10" >> /data/postgres/postgresql.conf
|
||||
|
||||
# Start PostgreSQL temporarily to create database and user
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres -l /tmp/postgres_init.log start"
|
||||
|
||||
|
|
@ -73,6 +103,35 @@ if [ ! -f /data/postgres/PG_VERSION ]; then
|
|||
# Enable pgvector extension
|
||||
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
|
||||
|
||||
# Create Electric SQL replication user (idempotent - uses IF NOT EXISTS)
|
||||
echo "📡 Creating Electric SQL replication user..."
|
||||
su - postgres -c "psql -d ${POSTGRES_DB:-surfsense} <<-EOSQL
|
||||
DO \\\$\\\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '${ELECTRIC_DB_USER}') THEN
|
||||
CREATE USER ${ELECTRIC_DB_USER} WITH REPLICATION PASSWORD '${ELECTRIC_DB_PASSWORD}';
|
||||
END IF;
|
||||
END
|
||||
\\\$\\\$;
|
||||
|
||||
GRANT CONNECT ON DATABASE ${POSTGRES_DB:-surfsense} TO ${ELECTRIC_DB_USER};
|
||||
GRANT USAGE ON SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO ${ELECTRIC_DB_USER};
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO ${ELECTRIC_DB_USER};
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO ${ELECTRIC_DB_USER};
|
||||
|
||||
-- Create the publication for Electric SQL (if not exists)
|
||||
DO \\\$\\\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
CREATE PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
END
|
||||
\\\$\\\$;
|
||||
EOSQL"
|
||||
echo "✅ Electric SQL user '${ELECTRIC_DB_USER}' created"
|
||||
|
||||
# Stop temporary PostgreSQL
|
||||
su - postgres -c "/usr/lib/postgresql/14/bin/pg_ctl -D /data/postgres stop"
|
||||
|
||||
|
|
@ -107,18 +166,23 @@ echo "🔧 Applying runtime environment configuration..."
|
|||
NEXT_PUBLIC_FASTAPI_BACKEND_URL="${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}"
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE="${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}"
|
||||
NEXT_PUBLIC_ETL_SERVICE="${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}"
|
||||
NEXT_PUBLIC_ELECTRIC_URL="${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE="${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}"
|
||||
|
||||
# Replace placeholders in all JS files
|
||||
find /app/frontend -type f \( -name "*.js" -o -name "*.json" \) -exec sed -i \
|
||||
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_URL__|${NEXT_PUBLIC_FASTAPI_BACKEND_URL}|g" \
|
||||
-e "s|__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__|${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ETL_SERVICE__|${NEXT_PUBLIC_ETL_SERVICE}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ELECTRIC_URL__|${NEXT_PUBLIC_ELECTRIC_URL}|g" \
|
||||
-e "s|__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__|${NEXT_PUBLIC_ELECTRIC_AUTH_MODE}|g" \
|
||||
{} +
|
||||
|
||||
echo "✅ Environment configuration applied"
|
||||
echo " Backend URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
|
||||
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
||||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " Backend URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
|
||||
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
||||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL}"
|
||||
|
||||
# ================================================
|
||||
# Run database migrations
|
||||
|
|
@ -161,6 +225,7 @@ echo "==========================================="
|
|||
echo " Frontend URL: http://localhost:3000"
|
||||
echo " Backend API: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}"
|
||||
echo " API Docs: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL}/docs"
|
||||
echo " Electric URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}"
|
||||
echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
||||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " TTS Service: ${TTS_SERVICE}"
|
||||
|
|
|
|||
56
scripts/docker/init-electric-user.sh
Executable file
56
scripts/docker/init-electric-user.sh
Executable file
|
|
@ -0,0 +1,56 @@
|
|||
#!/bin/sh
|
||||
# ============================================================================
|
||||
# Electric SQL User Initialization Script (docker-compose only)
|
||||
# ============================================================================
|
||||
# This script is ONLY used when running via docker-compose.
|
||||
#
|
||||
# How it works:
|
||||
# - docker-compose.yml mounts this script into the PostgreSQL container's
|
||||
# /docker-entrypoint-initdb.d/ directory
|
||||
# - PostgreSQL automatically executes scripts in that directory on first
|
||||
# container initialization
|
||||
#
|
||||
# For local PostgreSQL users (non-Docker), this script is NOT used.
|
||||
# Instead, the Electric user is created by Alembic migration 66
|
||||
# (66_add_notifications_table_and_electric_replication.py).
|
||||
#
|
||||
# Both approaches are idempotent (use IF NOT EXISTS), so running both
|
||||
# will not cause conflicts.
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Use environment variables with defaults
|
||||
ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
|
||||
ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
|
||||
|
||||
echo "Creating Electric SQL replication user: $ELECTRIC_DB_USER"
|
||||
|
||||
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
|
||||
DO \$\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '$ELECTRIC_DB_USER') THEN
|
||||
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
|
||||
END IF;
|
||||
END
|
||||
\$\$;
|
||||
|
||||
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT CREATE ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
|
||||
|
||||
-- Create the publication for Electric SQL (if not exists)
|
||||
DO \$\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
CREATE PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
END
|
||||
\$\$;
|
||||
EOSQL
|
||||
|
||||
echo "Electric SQL user '$ELECTRIC_DB_USER' and publication created successfully"
|
||||
|
|
@ -9,6 +9,10 @@ POSTGRES_USER=${POSTGRES_USER:-surfsense}
|
|||
POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-surfsense}
|
||||
POSTGRES_DB=${POSTGRES_DB:-surfsense}
|
||||
|
||||
# Electric SQL user credentials (configurable)
|
||||
ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
|
||||
echo "Initializing PostgreSQL..."
|
||||
|
||||
# Check if PostgreSQL is already initialized
|
||||
|
|
@ -23,8 +27,18 @@ fi
|
|||
# Configure PostgreSQL
|
||||
cat >> "$PGDATA/postgresql.conf" << EOF
|
||||
listen_addresses = '*'
|
||||
max_connections = 100
|
||||
shared_buffers = 128MB
|
||||
max_connections = 200
|
||||
shared_buffers = 256MB
|
||||
|
||||
# Enable logical replication (required for Electric SQL)
|
||||
wal_level = logical
|
||||
max_replication_slots = 10
|
||||
max_wal_senders = 10
|
||||
|
||||
# Performance settings
|
||||
checkpoint_timeout = 10min
|
||||
max_wal_size = 1GB
|
||||
min_wal_size = 80MB
|
||||
EOF
|
||||
|
||||
cat >> "$PGDATA/pg_hba.conf" << EOF
|
||||
|
|
@ -45,6 +59,15 @@ CREATE USER $POSTGRES_USER WITH PASSWORD '$POSTGRES_PASSWORD' SUPERUSER;
|
|||
CREATE DATABASE $POSTGRES_DB OWNER $POSTGRES_USER;
|
||||
\c $POSTGRES_DB
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- Create Electric SQL replication user
|
||||
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
|
||||
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
|
||||
EOF
|
||||
|
||||
echo "PostgreSQL initialized successfully."
|
||||
|
|
|
|||
20
scripts/docker/postgresql.conf
Normal file
20
scripts/docker/postgresql.conf
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# PostgreSQL configuration for Electric SQL
|
||||
# This file is mounted into the PostgreSQL container
|
||||
|
||||
listen_addresses = '*'
|
||||
max_connections = 200
|
||||
shared_buffers = 256MB
|
||||
|
||||
# Enable logical replication (required for Electric SQL)
|
||||
wal_level = logical
|
||||
max_replication_slots = 10
|
||||
max_wal_senders = 10
|
||||
|
||||
# Performance settings
|
||||
checkpoint_timeout = 10min
|
||||
max_wal_size = 1GB
|
||||
min_wal_size = 80MB
|
||||
|
||||
# Logging (optional, for debugging)
|
||||
# log_statement = 'all'
|
||||
# log_replication_commands = on
|
||||
|
|
@ -85,6 +85,20 @@ stderr_logfile=/dev/stderr
|
|||
stderr_logfile_maxbytes=0
|
||||
environment=PYTHONPATH="/app/backend"
|
||||
|
||||
# Electric SQL (real-time sync)
|
||||
[program:electric]
|
||||
command=/app/electric-release/bin/entrypoint start
|
||||
autostart=true
|
||||
autorestart=true
|
||||
priority=25
|
||||
startsecs=10
|
||||
startretries=3
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
environment=DATABASE_URL="%(ENV_ELECTRIC_DATABASE_URL)s",ELECTRIC_INSECURE="%(ENV_ELECTRIC_INSECURE)s",ELECTRIC_WRITE_TO_PG_MODE="%(ENV_ELECTRIC_WRITE_TO_PG_MODE)s",RELEASE_COOKIE="surfsense_electric_cookie",PORT="%(ENV_ELECTRIC_PORT)s"
|
||||
|
||||
# Frontend
|
||||
[program:frontend]
|
||||
command=node server.js
|
||||
|
|
@ -102,6 +116,6 @@ environment=NODE_ENV="production",PORT="3000",HOSTNAME="0.0.0.0"
|
|||
|
||||
# Process Groups
|
||||
[group:surfsense]
|
||||
programs=postgresql,redis,backend,celery-worker,celery-beat,frontend
|
||||
programs=postgresql,redis,electric,backend,celery-worker,celery-beat,frontend
|
||||
priority=999
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,13 @@ database_url = os.getenv("DATABASE_URL")
|
|||
if database_url:
|
||||
config.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Electric SQL user credentials - centralized configuration for migrations
|
||||
# These are used by migrations that set up Electric SQL replication
|
||||
config.set_main_option("electric_db_user", os.getenv("ELECTRIC_DB_USER", "electric"))
|
||||
config.set_main_option(
|
||||
"electric_db_password", os.getenv("ELECTRIC_DB_PASSWORD", "electric_password")
|
||||
)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
|
|
|||
|
|
@ -37,10 +37,5 @@ def upgrade() -> None:
|
|||
|
||||
def downgrade() -> None:
|
||||
"""Remove author_id column from new_chat_messages table."""
|
||||
op.execute(
|
||||
"""
|
||||
DROP INDEX IF EXISTS ix_new_chat_messages_author_id;
|
||||
ALTER TABLE new_chat_messages
|
||||
DROP COLUMN IF EXISTS author_id;
|
||||
"""
|
||||
)
|
||||
op.execute("DROP INDEX IF EXISTS ix_new_chat_messages_author_id")
|
||||
op.execute("ALTER TABLE new_chat_messages DROP COLUMN IF EXISTS author_id")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,172 @@
|
|||
"""Add notifications table and Electric SQL replication
|
||||
|
||||
Revision ID: 66
|
||||
Revises: 65
|
||||
|
||||
Creates notifications table and sets up Electric SQL replication
|
||||
(user, publication, REPLICA IDENTITY FULL) for notifications,
|
||||
search_source_connectors, and documents tables.
|
||||
|
||||
NOTE: Electric SQL user creation is idempotent (uses IF NOT EXISTS).
|
||||
- Docker deployments: user is pre-created by scripts/docker/init-electric-user.sh
|
||||
- Local PostgreSQL: user is created here during migration
|
||||
Both approaches are safe to run together without conflicts as this migraiton is idempotent
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import context, op
|
||||
|
||||
# Get Electric SQL user credentials from env.py configuration
|
||||
_config = context.config
|
||||
ELECTRIC_DB_USER = _config.get_main_option("electric_db_user", "electric")
|
||||
ELECTRIC_DB_PASSWORD = _config.get_main_option(
|
||||
"electric_db_password", "electric_password"
|
||||
)
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "66"
|
||||
down_revision: str | None = "65"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - add notifications table and Electric SQL replication."""
|
||||
# Create notifications table
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS notifications (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
title VARCHAR(200) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
metadata JSONB DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create indexes (using IF NOT EXISTS for idempotency)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_notifications_user_id ON notifications (user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_notifications_read ON notifications (read);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_notifications_created_at ON notifications (created_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_notifications_user_read ON notifications (user_id, read);"
|
||||
)
|
||||
|
||||
# =====================================================
|
||||
# Electric SQL Setup - User and Publication
|
||||
# =====================================================
|
||||
|
||||
# Create Electric SQL replication user if not exists
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '{ELECTRIC_DB_USER}') THEN
|
||||
CREATE USER {ELECTRIC_DB_USER} WITH REPLICATION PASSWORD '{ELECTRIC_DB_PASSWORD}';
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Grant necessary permissions to electric user
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
DECLARE
|
||||
db_name TEXT := current_database();
|
||||
BEGIN
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO {ELECTRIC_DB_USER}', db_name);
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
op.execute(f"GRANT USAGE ON SCHEMA public TO {ELECTRIC_DB_USER};")
|
||||
op.execute(f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {ELECTRIC_DB_USER};")
|
||||
op.execute(f"GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO {ELECTRIC_DB_USER};")
|
||||
op.execute(
|
||||
f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO {ELECTRIC_DB_USER};"
|
||||
)
|
||||
op.execute(
|
||||
f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO {ELECTRIC_DB_USER};"
|
||||
)
|
||||
|
||||
# Create the publication if not exists
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
CREATE PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# =====================================================
|
||||
# Electric SQL Setup - Table Configuration
|
||||
# =====================================================
|
||||
|
||||
# Set REPLICA IDENTITY FULL (required by Electric SQL for replication)
|
||||
op.execute("ALTER TABLE notifications REPLICA IDENTITY FULL;")
|
||||
op.execute("ALTER TABLE search_source_connectors REPLICA IDENTITY FULL;")
|
||||
op.execute("ALTER TABLE documents REPLICA IDENTITY FULL;")
|
||||
|
||||
# Add tables to Electric SQL publication for replication
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Add notifications if not already added
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_publication_tables
|
||||
WHERE pubname = 'electric_publication_default'
|
||||
AND tablename = 'notifications'
|
||||
) THEN
|
||||
ALTER PUBLICATION electric_publication_default ADD TABLE notifications;
|
||||
END IF;
|
||||
|
||||
-- Add search_source_connectors if not already added
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_publication_tables
|
||||
WHERE pubname = 'electric_publication_default'
|
||||
AND tablename = 'search_source_connectors'
|
||||
) THEN
|
||||
ALTER PUBLICATION electric_publication_default ADD TABLE search_source_connectors;
|
||||
END IF;
|
||||
|
||||
-- Add documents if not already added
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_publication_tables
|
||||
WHERE pubname = 'electric_publication_default'
|
||||
AND tablename = 'documents'
|
||||
) THEN
|
||||
ALTER PUBLICATION electric_publication_default ADD TABLE documents;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - remove notifications table."""
|
||||
op.drop_index("ix_notifications_user_read", table_name="notifications")
|
||||
op.drop_index("ix_notifications_created_at", table_name="notifications")
|
||||
op.drop_index("ix_notifications_read", table_name="notifications")
|
||||
op.drop_index("ix_notifications_user_id", table_name="notifications")
|
||||
op.drop_table("notifications")
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
"""Add pg_trgm indexes for efficient document title search
|
||||
|
||||
Revision ID: 67
|
||||
Revises: 66
|
||||
|
||||
Adds the pg_trgm extension and GIN trigram indexes on documents.title
|
||||
to enable efficient ILIKE searches with leading wildcards (e.g., '%search_term%').
|
||||
|
||||
Indexes added:
|
||||
1. idx_documents_title_trgm - GIN trigram on title for ILIKE '%term%'
|
||||
2. idx_documents_search_space_id - B-tree on search_space_id for filtering
|
||||
3. idx_documents_search_space_updated - Composite for recent docs query (covering index)
|
||||
4. idx_surfsense_docs_title_trgm - GIN trigram on surfsense docs title
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "67"
|
||||
down_revision: str | None = "66"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add pg_trgm extension and optimized indexes for document search."""
|
||||
|
||||
# Create pg_trgm extension if not exists
|
||||
# This extension provides trigram-based text similarity functions and operators
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
|
||||
|
||||
# 1. GIN trigram index on documents.title for ILIKE '%term%' searches
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_title_trgm
|
||||
ON documents USING gin (title gin_trgm_ops);
|
||||
"""
|
||||
)
|
||||
|
||||
# 2. B-tree index on search_space_id for fast filtering
|
||||
# (Every query filters by search_space_id first)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_search_space_id
|
||||
ON documents (search_space_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# 3. Covering index for "recent documents" query (no search term)
|
||||
# Includes id, title, document_type so PostgreSQL can do index-only scan
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated
|
||||
ON documents (search_space_id, updated_at DESC NULLS LAST)
|
||||
INCLUDE (id, title, document_type);
|
||||
"""
|
||||
)
|
||||
|
||||
# 4. GIN trigram index on surfsense_docs_documents.title
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_surfsense_docs_title_trgm
|
||||
ON surfsense_docs_documents USING gin (title gin_trgm_ops);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove all document search indexes (extension is left in place)."""
|
||||
op.execute("DROP INDEX IF EXISTS idx_surfsense_docs_title_trgm;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_documents_search_space_updated;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_documents_search_space_id;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_documents_title_trgm;")
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""Add chat_comments table for comments on AI responses
|
||||
|
||||
Revision ID: 68
|
||||
Revises: 67
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "68"
|
||||
down_revision: str | None = "67"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create chat_comments table."""
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS chat_comments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
message_id INTEGER NOT NULL REFERENCES new_chat_messages(id) ON DELETE CASCADE,
|
||||
parent_id INTEGER REFERENCES chat_comments(id) ON DELETE CASCADE,
|
||||
author_id UUID REFERENCES "user"(id) ON DELETE SET NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chat_comments_message_id ON chat_comments(message_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chat_comments_parent_id ON chat_comments(parent_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chat_comments_author_id ON chat_comments(author_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chat_comments_created_at ON chat_comments(created_at)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop chat_comments table."""
|
||||
op.execute(
|
||||
"""
|
||||
DROP TABLE IF EXISTS chat_comments;
|
||||
"""
|
||||
)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Add chat_comment_mentions table for @mentions in comments
|
||||
|
||||
Revision ID: 69
|
||||
Revises: 68
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "69"
|
||||
down_revision: str | None = "68"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create chat_comment_mentions table."""
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS chat_comment_mentions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
comment_id INTEGER NOT NULL REFERENCES chat_comments(id) ON DELETE CASCADE,
|
||||
mentioned_user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (comment_id, mentioned_user_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chat_comment_mentions_comment_id ON chat_comment_mentions(comment_id)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop chat_comment_mentions table."""
|
||||
op.execute(
|
||||
"""
|
||||
DROP TABLE IF EXISTS chat_comment_mentions;
|
||||
"""
|
||||
)
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""Add comments permissions to existing roles
|
||||
|
||||
Revision ID: 70
|
||||
Revises: 69
|
||||
Create Date: 2024-01-16
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "70"
|
||||
down_revision = "69"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
connection = op.get_bind()
|
||||
|
||||
# Add comments:create to Admin, Editor, Viewer roles (if not already present)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_append(permissions, 'comments:create')
|
||||
WHERE name IN ('Admin', 'Editor', 'Viewer')
|
||||
AND NOT ('comments:create' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Add comments:read to Admin, Editor, Viewer roles (if not already present)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_append(permissions, 'comments:read')
|
||||
WHERE name IN ('Admin', 'Editor', 'Viewer')
|
||||
AND NOT ('comments:read' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Add comments:delete to Admin roles only (if not already present)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_append(permissions, 'comments:delete')
|
||||
WHERE name = 'Admin'
|
||||
AND NOT ('comments:delete' = ANY(permissions))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
connection = op.get_bind()
|
||||
|
||||
# Remove comments:create from Admin, Editor, Viewer roles
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(permissions, 'comments:create')
|
||||
WHERE name IN ('Admin', 'Editor', 'Viewer')
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Remove comments:read from Admin, Editor, Viewer roles
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(permissions, 'comments:read')
|
||||
WHERE name IN ('Admin', 'Editor', 'Viewer')
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Remove comments:delete from Admin roles only
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE search_space_roles
|
||||
SET permissions = array_remove(permissions, 'comments:delete')
|
||||
WHERE name = 'Admin'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""Add Electric SQL replication for chat_comment_mentions table
|
||||
|
||||
Revision ID: 71
|
||||
Revises: 70
|
||||
|
||||
Enables Electric SQL replication for the chat_comment_mentions table to support
|
||||
real-time live updates for mentions.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "71"
|
||||
down_revision: str | None = "70"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Enable Electric SQL replication for chat_comment_mentions table."""
|
||||
op.execute("ALTER TABLE chat_comment_mentions REPLICA IDENTITY FULL;")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_publication_tables
|
||||
WHERE pubname = 'electric_publication_default'
|
||||
AND tablename = 'chat_comment_mentions'
|
||||
) THEN
|
||||
ALTER PUBLICATION electric_publication_default ADD TABLE chat_comment_mentions;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove chat_comment_mentions from Electric SQL replication."""
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_publication_tables
|
||||
WHERE pubname = 'electric_publication_default'
|
||||
AND tablename = 'chat_comment_mentions'
|
||||
) THEN
|
||||
ALTER PUBLICATION electric_publication_default DROP TABLE chat_comment_mentions;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute("ALTER TABLE chat_comment_mentions REPLICA IDENTITY DEFAULT;")
|
||||
|
|
@ -116,47 +116,6 @@ You have access to the following tools:
|
|||
* This makes your response more visual and engaging.
|
||||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
||||
6. write_todos: Create and update a planning/todo list.
|
||||
- Args:
|
||||
- todos: List of todo items, each with:
|
||||
* content: Description of the task (required)
|
||||
* status: "pending", "in_progress", "completed", or "cancelled" (required)
|
||||
|
||||
STRICT MODE SELECTION - CHOOSE ONE:
|
||||
|
||||
[MODE A] AGENT PLAN (you will work through it)
|
||||
Use when: User asks you to explain, teach, plan, or break down a concept.
|
||||
Examples: "Explain how to set up Python", "Plan my trip", "Break down machine learning"
|
||||
Rules:
|
||||
- Create plan with first item "in_progress", rest "pending"
|
||||
- After explaining each step, call write_todos again to update progress
|
||||
- Only ONE item "in_progress" at a time
|
||||
- Mark items "completed" as you finish explaining them
|
||||
- Final call: all items "completed"
|
||||
|
||||
[MODE B] EXTERNAL TASK DISPLAY (from connectors - you CANNOT complete these)
|
||||
Use when: User asks to show/list/display tasks from Linear, Jira, ClickUp, GitHub, Airtable, Notion, or any connector.
|
||||
Examples: "Show my Linear tasks", "List Jira tickets", "Create todos from ClickUp", "Show GitHub issues"
|
||||
STRICT RULES:
|
||||
1. You CANNOT complete these tasks - only the user can in the actual tool
|
||||
2. PRESERVE original status from source - DO NOT use agent workflow
|
||||
3. Call write_todos ONCE with all tasks and their REAL statuses
|
||||
4. Provide insights/summary as TEXT after the todo list, NOT as todo items
|
||||
5. NO INTERNAL REASONING - Never expose your process. Do NOT say "Let me map...", "Converting statuses...", "Here's how I'll organize...", or explain mapping logic. Just call write_todos silently and provide insights.
|
||||
|
||||
STATUS MAPPING (apply strictly):
|
||||
- "completed" ← Done, Completed, Complete, Closed, Resolved, Fixed, Merged, Shipped, Released
|
||||
- "in_progress" ← In Progress, In Review, Testing, QA, Active, Doing, Started, Review, Working
|
||||
- "pending" ← Todo, To Do, Backlog, Open, New, Pending, Triage, Reopened, Unstarted
|
||||
- "cancelled" ← Cancelled, Canceled, Won't Fix, Duplicate, Invalid, Rejected, Archived, Obsolete
|
||||
|
||||
CONNECTOR-SPECIFIC:
|
||||
- Linear: state.name = "Done", "In Progress", "Todo", "Backlog", "Cancelled"
|
||||
- Jira: statusCategory.name = "To Do", "In Progress", "Done"
|
||||
- ClickUp: status = "complete", "in progress", "open", "closed"
|
||||
- GitHub: state = "open", "closed"; PRs also "merged"
|
||||
- Airtable/Notion: Check field values, apply mapping above
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "How do I install SurfSense?"
|
||||
|
|
@ -189,12 +148,24 @@ You have access to the following tools:
|
|||
|
||||
- User: "Check out https://dev.to/some-article"
|
||||
- Call: `link_preview(url="https://dev.to/some-article")`
|
||||
- Call: `scrape_webpage(url="https://dev.to/some-article")`
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your analysis, referencing the displayed image
|
||||
|
||||
- User: "What's this blog post about? https://example.com/blog/post"
|
||||
- Call: `link_preview(url="https://example.com/blog/post")`
|
||||
- Call: `scrape_webpage(url="https://example.com/blog/post")`
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your analysis, referencing the displayed image
|
||||
|
||||
- User: "https://github.com/some/repo"
|
||||
- Call: `link_preview(url="https://github.com/some/repo")`
|
||||
- Call: `scrape_webpage(url="https://github.com/some/repo")`
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your analysis, referencing the displayed image
|
||||
|
||||
- User: "Show me this image: https://example.com/image.png"
|
||||
- Call: `display_image(src="https://example.com/image.png", alt="User shared image")`
|
||||
|
|
@ -210,86 +181,31 @@ You have access to the following tools:
|
|||
- The user can already see their screenshot - they don't need you to display it again.
|
||||
|
||||
- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
|
||||
- Call: `link_preview(url="https://example.com/blog/ai-trends")`
|
||||
- Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
|
||||
- After getting the content, provide a summary based on the scraped text
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide a summary based on the scraped text
|
||||
|
||||
- User: "What does this page say about machine learning? https://docs.example.com/ml-guide"
|
||||
- Call: `link_preview(url="https://docs.example.com/ml-guide")`
|
||||
- Call: `scrape_webpage(url="https://docs.example.com/ml-guide")`
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then answer the question using the extracted content
|
||||
|
||||
- User: "Summarize this blog post: https://medium.com/some-article"
|
||||
- Call: `link_preview(url="https://medium.com/some-article")`
|
||||
- Call: `scrape_webpage(url="https://medium.com/some-article")`
|
||||
- Provide a comprehensive summary of the article content
|
||||
- After getting the content, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide a comprehensive summary of the article content
|
||||
|
||||
- User: "Read this tutorial and explain it: https://example.com/ml-tutorial"
|
||||
- First: `scrape_webpage(url="https://example.com/ml-tutorial")`
|
||||
- Then, if the content contains useful diagrams/images like ``:
|
||||
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
|
||||
- Then provide your explanation, referencing the displayed image
|
||||
|
||||
[MODE A EXAMPLES] Agent Plan - you work through it:
|
||||
|
||||
- User: "Create a plan for building a user authentication system"
|
||||
- Call: `write_todos(todos=[{"content": "Design database schema for users and sessions", "status": "in_progress"}, {"content": "Implement registration and login endpoints", "status": "pending"}, {"content": "Add password reset functionality", "status": "pending"}])`
|
||||
- Then explain each step in detail as you work through them
|
||||
|
||||
- User: "Break down how to build a REST API into steps"
|
||||
- Call: `write_todos(todos=[{"content": "Design API endpoints and data models", "status": "in_progress"}, {"content": "Set up server framework and routing", "status": "pending"}, {"content": "Implement CRUD operations", "status": "pending"}, {"content": "Add authentication and error handling", "status": "pending"}])`
|
||||
- Then provide detailed explanations for each step
|
||||
|
||||
- User: "Help me plan my trip to Japan"
|
||||
- Call: `write_todos(todos=[{"content": "Research best time to visit and book flights", "status": "in_progress"}, {"content": "Plan itinerary for cities to visit", "status": "pending"}, {"content": "Book accommodations", "status": "pending"}, {"content": "Prepare travel documents and currency", "status": "pending"}])`
|
||||
- Then provide travel preparation guidance
|
||||
|
||||
- COMPLETE WORKFLOW EXAMPLE - User: "Explain how to set up a Python project"
|
||||
- STEP 1 (Create initial plan):
|
||||
Call: `write_todos(todos=[{"content": "Set up virtual environment", "status": "in_progress"}, {"content": "Create project structure", "status": "pending"}, {"content": "Configure dependencies", "status": "pending"}])`
|
||||
Then explain virtual environment setup in detail...
|
||||
- STEP 2 (After explaining virtual environments, update progress):
|
||||
Call: `write_todos(todos=[{"content": "Set up virtual environment", "status": "completed"}, {"content": "Create project structure", "status": "in_progress"}, {"content": "Configure dependencies", "status": "pending"}])`
|
||||
Then explain project structure in detail...
|
||||
- STEP 3 (After explaining project structure, update progress):
|
||||
Call: `write_todos(todos=[{"content": "Set up virtual environment", "status": "completed"}, {"content": "Create project structure", "status": "completed"}, {"content": "Configure dependencies", "status": "in_progress"}])`
|
||||
Then explain dependency configuration in detail...
|
||||
- STEP 4 (After completing all explanations, mark all done):
|
||||
Call: `write_todos(todos=[{"content": "Set up virtual environment", "status": "completed"}, {"content": "Create project structure", "status": "completed"}, {"content": "Configure dependencies", "status": "completed"}])`
|
||||
Provide final summary
|
||||
|
||||
[MODE B EXAMPLES] External Tasks - preserve original status, you CANNOT complete:
|
||||
|
||||
- User: "Show my Linear tasks" or "Create todos for Linear tasks"
|
||||
- First search: `search_knowledge_base(query="Linear tasks issues", connectors_to_search=["LINEAR_CONNECTOR"])`
|
||||
- Then call write_todos ONCE with ORIGINAL statuses preserved:
|
||||
Call: `write_todos(todos=[
|
||||
{"content": "SUR-21: Add refresh button in manage documents page", "status": "completed"},
|
||||
{"content": "SUR-22: Logs page not accessible in docker", "status": "completed"},
|
||||
{"content": "SUR-27: Add Google Drive connector", "status": "in_progress"},
|
||||
{"content": "SUR-28: Logs page should show all logs", "status": "pending"}
|
||||
])`
|
||||
- Then provide INSIGHTS as text (NOT as todos):
|
||||
"You have 2 completed, 1 in progress, and 1 pending task. SUR-27 (Google Drive connector) is currently active. Consider prioritizing SUR-28 next."
|
||||
|
||||
- User: "List my Jira tickets"
|
||||
- First search: `search_knowledge_base(query="Jira tickets issues", connectors_to_search=["JIRA_CONNECTOR"])`
|
||||
- Map Jira statuses: "Done" → completed, "In Progress"/"In Review" → in_progress, "To Do" → pending
|
||||
- Call write_todos ONCE with mapped statuses
|
||||
- Provide summary as text after
|
||||
|
||||
- User: "Show ClickUp tasks"
|
||||
- First search: `search_knowledge_base(query="ClickUp tasks", connectors_to_search=["CLICKUP_CONNECTOR"])`
|
||||
- Map: "complete"/"closed" → completed, "in progress" → in_progress, "open" → pending
|
||||
- Call write_todos ONCE, then provide insights as text
|
||||
|
||||
- User: "Show my GitHub issues"
|
||||
- First search: `search_knowledge_base(query="GitHub issues", connectors_to_search=["GITHUB_CONNECTOR"])`
|
||||
- Map: "closed"/"merged" → completed, "open" → pending
|
||||
- Call write_todos ONCE, then summarize as text
|
||||
|
||||
CRITICAL FOR MODE B:
|
||||
- NEVER use the "first item in_progress, rest pending" pattern for external tasks
|
||||
- NEVER pretend you will complete external tasks - be honest that only the user can
|
||||
- ALWAYS preserve the actual status from the source system
|
||||
- ALWAYS provide insights/summaries as regular text, not as todo items
|
||||
</tool_call_examples>
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
"""MCP Client Wrapper.
|
||||
|
||||
This module provides a client for communicating with MCP servers via stdio transport.
|
||||
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
|
||||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -11,9 +12,15 @@ from typing import Any
|
|||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 1.0 # seconds
|
||||
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Client for communicating with an MCP server."""
|
||||
|
|
@ -35,44 +42,86 @@ class MCPClient:
|
|||
self.session: ClientSession | None = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||
"""Connect to the MCP server and manage its lifecycle.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of connection retry attempts
|
||||
|
||||
Yields:
|
||||
ClientSession: Active MCP session for making requests
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all connection attempts fail
|
||||
|
||||
"""
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
last_error = None
|
||||
delay = RETRY_DELAY
|
||||
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
"Connected to MCP server on attempt %d: %s %s",
|
||||
attempt + 1,
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
yield session
|
||||
return # Success, exit retry loop
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
yield session
|
||||
await asyncio.sleep(delay)
|
||||
delay *= RETRY_BACKOFF # Exponential backoff
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to connect to MCP server after %d attempts: %s",
|
||||
max_retries,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.session = None
|
||||
logger.info("Disconnected from MCP server: %s", self.command)
|
||||
# All retries exhausted
|
||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg) from last_error
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""List all tools available from the MCP server.
|
||||
|
|
@ -174,7 +223,7 @@ class MCPClient:
|
|||
async def test_mcp_connection(
|
||||
command: str, args: list[str], env: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server and fetch available tools.
|
||||
"""Test connection to an MCP server via stdio and fetch available tools.
|
||||
|
||||
Args:
|
||||
command: Command to spawn the MCP server
|
||||
|
|
@ -201,3 +250,61 @@ async def test_mcp_connection(
|
|||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
|
||||
async def test_mcp_http_connection(
|
||||
url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server via HTTP and fetch available tools.
|
||||
|
||||
Args:
|
||||
url: URL of the MCP server
|
||||
headers: Optional HTTP headers for authentication
|
||||
transport: Transport type ("streamable-http", "http", or "sse")
|
||||
|
||||
Returns:
|
||||
Dict with connection status and available tools
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Testing HTTP MCP connection to: %s (transport: %s)", url, transport
|
||||
)
|
||||
|
||||
# Use streamable HTTP client for all HTTP-based transports
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers or {}) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"HTTP MCP connection successful. Found %d tools.", len(tools)
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Connected successfully. Found {len(tools)} tools.",
|
||||
"tools": tools,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to HTTP MCP server: %s", e, exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@
|
|||
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
||||
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
||||
|
||||
Supports both transport types:
|
||||
- stdio: Local process-based MCP servers (command, args, env)
|
||||
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
"""
|
||||
|
||||
|
|
@ -10,6 +14,8 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -65,11 +71,11 @@ def _create_dynamic_input_model_from_schema(
|
|||
return create_model(model_name, **field_definitions)
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition(
|
||||
async def _create_mcp_tool_from_definition_stdio(
|
||||
tool_def: dict[str, Any],
|
||||
mcp_client: MCPClient,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition.
|
||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
|
|
@ -90,16 +96,22 @@ async def _create_mcp_tool_from_definition(
|
|||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client."""
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
try:
|
||||
# Connect to server and call tool
|
||||
# Connect to server and call tool (connect has built-in retry logic)
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
# Connection failures after all retries
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
except Exception as e:
|
||||
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
|
||||
# Tool execution or other errors
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
|
|
@ -110,13 +122,239 @@ async def _create_mcp_tool_from_definition(
|
|||
coroutine=mcp_tool_call,
|
||||
args_schema=input_model,
|
||||
# Store the original MCP schema as metadata so we can access it later
|
||||
metadata={"mcp_input_schema": input_schema},
|
||||
metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool: '{tool_name}'")
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
return tool
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition_http(
|
||||
tool_def: dict[str, Any],
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
url: URL of the MCP server
|
||||
headers: HTTP headers for authentication
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool
|
||||
response = await session.call_tool(tool_name, arguments=kwargs)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(
|
||||
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
|
||||
)
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_http_tool_call,
|
||||
args_schema=input_model,
|
||||
metadata={
|
||||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
||||
return tool
|
||||
|
||||
|
||||
async def _load_stdio_mcp_tools(
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from a stdio-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with command, args, env
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
|
||||
f"'{command}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def _load_http_mcp_tools(
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with url, headers
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required url field
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate headers field (must be dict if present)
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Connect and discover tools via HTTP
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tool_definitions = []
|
||||
for tool in response.tools:
|
||||
tool_definitions.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
||||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(
|
||||
tool_def, url, headers
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create HTTP tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -124,6 +362,7 @@ async def load_mcp_tools(
|
|||
"""Load all MCP tools from user's active MCP server connectors.
|
||||
|
||||
This discovers tools dynamically from MCP servers using the protocol.
|
||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
|
@ -146,48 +385,36 @@ async def load_mcp_tools(
|
|||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
try:
|
||||
# Extract server config
|
||||
# Early validation: Extract and validate connector config
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
|
||||
command = server_config.get("command")
|
||||
args = server_config.get("args", [])
|
||||
env = server_config.get("env", {})
|
||||
|
||||
if not command:
|
||||
# Validate server_config exists and is a dict
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} missing command, skipping"
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
# Determine transport type
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from MCP server "
|
||||
f"'{command}' (connector {connector.id})"
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
# HTTP-based MCP server
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
else:
|
||||
# stdio-based MCP server (default)
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition(
|
||||
tool_def, mcp_client
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector.id}: {e!s}",
|
||||
)
|
||||
tools.extend(connector_tools)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}",
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ async def get_changes(
|
|||
params = {
|
||||
"pageToken": page_token,
|
||||
"pageSize": 100,
|
||||
"fields": "nextPageToken, newStartPageToken, changes(fileId, removed, file(id, name, mimeType, modifiedTime, size, webViewLink, parents, trashed))",
|
||||
"fields": "nextPageToken, newStartPageToken, changes(fileId, removed, file(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, trashed))",
|
||||
"supportsAllDrives": True,
|
||||
"includeItemsFromAllDrives": True,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class GoogleDriveClient:
|
|||
async def list_files(
|
||||
self,
|
||||
query: str = "",
|
||||
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, size, webViewLink, parents, owners, createdTime, description)",
|
||||
fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
|
||||
page_size: int = 100,
|
||||
page_token: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
||||
|
|
|
|||
|
|
@ -102,6 +102,8 @@ async def download_and_process_file(
|
|||
connector_info["metadata"]["file_size"] = file["size"]
|
||||
if "webViewLink" in file:
|
||||
connector_info["metadata"]["web_view_link"] = file["webViewLink"]
|
||||
if "md5Checksum" in file:
|
||||
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
|
||||
|
||||
if is_google_workspace_file(mime_type):
|
||||
connector_info["metadata"]["exported_as"] = "pdf"
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ async def get_file_by_id(
|
|||
try:
|
||||
file, error = await client.get_file_metadata(
|
||||
file_id,
|
||||
fields="id, name, mimeType, parents, createdTime, modifiedTime, size, webViewLink, iconLink",
|
||||
fields="id, name, mimeType, parents, createdTime, modifiedTime, md5Checksum, size, webViewLink, iconLink",
|
||||
)
|
||||
|
||||
if error:
|
||||
|
|
@ -228,7 +228,7 @@ async def list_folder_contents(
|
|||
while True:
|
||||
items, next_token, error = await client.list_files(
|
||||
query=query,
|
||||
fields="files(id, name, mimeType, parents, createdTime, modifiedTime, size, webViewLink, iconLink)",
|
||||
fields="files(id, name, mimeType, parents, createdTime, modifiedTime, md5Checksum, size, webViewLink, iconLink)",
|
||||
page_size=1000, # Max allowed by Google Drive API
|
||||
page_token=page_token,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -152,6 +152,11 @@ class Permission(str, Enum):
|
|||
CHATS_UPDATE = "chats:update"
|
||||
CHATS_DELETE = "chats:delete"
|
||||
|
||||
# Comments
|
||||
COMMENTS_CREATE = "comments:create"
|
||||
COMMENTS_READ = "comments:read"
|
||||
COMMENTS_DELETE = "comments:delete"
|
||||
|
||||
# LLM Configs
|
||||
LLM_CONFIGS_CREATE = "llm_configs:create"
|
||||
LLM_CONFIGS_READ = "llm_configs:read"
|
||||
|
|
@ -209,6 +214,10 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.CHATS_READ.value,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
Permission.CHATS_DELETE.value,
|
||||
# Comments
|
||||
Permission.COMMENTS_CREATE.value,
|
||||
Permission.COMMENTS_READ.value,
|
||||
Permission.COMMENTS_DELETE.value,
|
||||
# LLM Configs
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
|
|
@ -252,6 +261,9 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.CHATS_READ.value,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
Permission.CHATS_DELETE.value,
|
||||
# Comments (no delete)
|
||||
Permission.COMMENTS_CREATE.value,
|
||||
Permission.COMMENTS_READ.value,
|
||||
# LLM Configs (read only)
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
|
|
@ -279,6 +291,9 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.DOCUMENTS_READ.value,
|
||||
# Chats (read only)
|
||||
Permission.CHATS_READ.value,
|
||||
# Comments (no delete)
|
||||
Permission.COMMENTS_CREATE.value,
|
||||
Permission.COMMENTS_READ.value,
|
||||
# LLM Configs (read only)
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
# Podcasts (read only)
|
||||
|
|
@ -424,6 +439,84 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
# Relationships
|
||||
thread = relationship("NewChatThread", back_populates="messages")
|
||||
author = relationship("User")
|
||||
comments = relationship(
|
||||
"ChatComment",
|
||||
back_populates="message",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class ChatComment(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Comment model for comments on AI chat responses.
|
||||
Supports one level of nesting (replies to comments, but no replies to replies).
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_comments"
|
||||
|
||||
message_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_messages.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
parent_id = Column(
|
||||
Integer,
|
||||
ForeignKey("chat_comments.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
author_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
content = Column(Text, nullable=False)
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
message = relationship("NewChatMessage", back_populates="comments")
|
||||
author = relationship("User")
|
||||
parent = relationship(
|
||||
"ChatComment", remote_side="ChatComment.id", backref="replies"
|
||||
)
|
||||
mentions = relationship(
|
||||
"ChatCommentMention",
|
||||
back_populates="comment",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class ChatCommentMention(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Tracks @mentions in chat comments for notification purposes.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_comment_mentions"
|
||||
|
||||
comment_id = Column(
|
||||
Integer,
|
||||
ForeignKey("chat_comments.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
mentioned_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
comment = relationship("ChatComment", back_populates="mentions")
|
||||
mentioned_user = relationship("User")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
|
|
@ -574,6 +667,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Log.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
notifications = relationship(
|
||||
"Notification",
|
||||
back_populates="search_space",
|
||||
order_by="Notification.created_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
search_source_connectors = relationship(
|
||||
"SearchSourceConnector",
|
||||
back_populates="search_space",
|
||||
|
|
@ -712,6 +811,39 @@ class Log(BaseModel, TimestampMixin):
|
|||
search_space = relationship("SearchSpace", back_populates="logs")
|
||||
|
||||
|
||||
class Notification(BaseModel, TimestampMixin):
|
||||
__tablename__ = "notifications"
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
type = Column(
|
||||
String(50), nullable=False
|
||||
) # 'connector_indexing', 'document_processing', etc.
|
||||
title = Column(String(200), nullable=False)
|
||||
message = Column(Text, nullable=False)
|
||||
read = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false"), index=True
|
||||
)
|
||||
notification_metadata = Column("metadata", JSONB, nullable=True, default={})
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
user = relationship("User", back_populates="notifications")
|
||||
search_space = relationship("SearchSpace", back_populates="notifications")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Custom roles that can be defined per search space.
|
||||
|
|
@ -856,6 +988,12 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
notifications = relationship(
|
||||
"Notification",
|
||||
back_populates="user",
|
||||
order_by="Notification.created_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# RBAC relationships
|
||||
search_space_memberships = relationship(
|
||||
|
|
@ -893,6 +1031,12 @@ else:
|
|||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
notifications = relationship(
|
||||
"Notification",
|
||||
back_populates="user",
|
||||
order_by="Notification.created_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# RBAC relationships
|
||||
search_space_memberships = relationship(
|
||||
|
|
@ -956,11 +1100,36 @@ async def setup_indexes():
|
|||
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
|
||||
)
|
||||
)
|
||||
# pg_trgm indexes for efficient ILIKE '%term%' searches on titles
|
||||
# Critical for document mention picker (@mentions) to scale
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)"
|
||||
)
|
||||
)
|
||||
# B-tree index on search_space_id for fast filtering
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)"
|
||||
)
|
||||
)
|
||||
# Covering index for "recent documents" query - enables index-only scan
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_surfsense_docs_title_trgm ON surfsense_docs_documents USING gin (title gin_trgm_ops)"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await setup_indexes()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi import APIRouter
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from .chat_comments_routes import router as chat_comments_router
|
||||
from .circleback_webhook_route import router as circleback_webhook_router
|
||||
from .clickup_add_connector_route import router as clickup_add_connector_router
|
||||
from .confluence_add_connector_route import router as confluence_add_connector_router
|
||||
|
|
@ -25,6 +26,7 @@ from .luma_add_connector_route import router as luma_add_connector_router
|
|||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .notifications_routes import router as notifications_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
|
|
@ -42,6 +44,7 @@ router.include_router(editor_router)
|
|||
router.include_router(documents_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
|
|
@ -61,3 +64,4 @@ router.include_router(new_llm_config_router) # LLM configs with prompt configur
|
|||
router.include_router(logs_router)
|
||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||
router.include_router(notifications_router) # Notifications with Electric SQL sync
|
||||
|
|
|
|||
95
surfsense_backend/app/routes/chat_comments_routes.py
Normal file
95
surfsense_backend/app/routes/chat_comments_routes.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
Routes for chat comments and mentions.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentCreateRequest,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
CommentResponse,
|
||||
CommentUpdateRequest,
|
||||
MentionListResponse,
|
||||
)
|
||||
from app.services.chat_comments_service import (
|
||||
create_comment,
|
||||
create_reply,
|
||||
delete_comment,
|
||||
get_comments_for_message,
|
||||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List all comments for a message with their replies."""
|
||||
return await get_comments_for_message(session, message_id, user)
|
||||
|
||||
|
||||
@router.post("/messages/{message_id}/comments", response_model=CommentResponse)
|
||||
async def add_comment(
|
||||
message_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a top-level comment on an AI response."""
|
||||
return await create_comment(session, message_id, request.content, user)
|
||||
|
||||
|
||||
@router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse)
|
||||
async def add_reply(
|
||||
comment_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Reply to an existing comment."""
|
||||
return await create_reply(session, comment_id, request.content, user)
|
||||
|
||||
|
||||
@router.put("/comments/{comment_id}", response_model=CommentReplyResponse)
|
||||
async def edit_comment(
|
||||
comment_id: int,
|
||||
request: CommentUpdateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update a comment's content (author only)."""
|
||||
return await update_comment(session, comment_id, request.content, user)
|
||||
|
||||
|
||||
@router.delete("/comments/{comment_id}")
|
||||
async def remove_comment(
|
||||
comment_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a comment (author or user with COMMENTS_DELETE permission)."""
|
||||
return await delete_comment(session, comment_id, user)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mention Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/mentions", response_model=MentionListResponse)
|
||||
async def list_mentions(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List mentions for the current user."""
|
||||
return await get_user_mentions(session, user, search_space_id)
|
||||
|
|
@ -19,6 +19,8 @@ from app.db import (
|
|||
from app.schemas import (
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentTitleRead,
|
||||
DocumentTitleSearchResponse,
|
||||
DocumentUpdate,
|
||||
DocumentWithChunksRead,
|
||||
PaginatedResponse,
|
||||
|
|
@ -429,6 +431,112 @@ async def search_documents(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/search/titles", response_model=DocumentTitleSearchResponse)
|
||||
async def search_document_titles(
|
||||
search_space_id: int,
|
||||
title: str = "",
|
||||
page: int = 0,
|
||||
page_size: int = 20,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Lightweight document title search optimized for mention picker (@mentions).
|
||||
|
||||
Returns only id, title, and document_type - no content or metadata.
|
||||
Uses pg_trgm fuzzy search with similarity scoring for typo tolerance.
|
||||
Results are ordered by relevance using trigram similarity scores.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space to search in. Required.
|
||||
title: Search query (case-insensitive). If empty or < 2 chars, returns recent documents.
|
||||
page: Zero-based page index. Default: 0.
|
||||
page_size: Number of items per page. Default: 20.
|
||||
session: Database session (injected).
|
||||
user: Current authenticated user (injected).
|
||||
|
||||
Returns:
|
||||
DocumentTitleSearchResponse: Lightweight list with has_more flag (no total count).
|
||||
"""
|
||||
from sqlalchemy import desc, func, or_
|
||||
|
||||
try:
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
# Base query - only select lightweight fields
|
||||
query = select(
|
||||
Document.id,
|
||||
Document.title,
|
||||
Document.document_type,
|
||||
).filter(Document.search_space_id == search_space_id)
|
||||
|
||||
# If query is too short, return recent documents ordered by updated_at
|
||||
if len(title.strip()) < 2:
|
||||
query = query.order_by(Document.updated_at.desc().nullslast())
|
||||
else:
|
||||
# Fuzzy search using pg_trgm similarity + ILIKE fallback
|
||||
search_term = title.strip()
|
||||
|
||||
# Similarity threshold for fuzzy matching (0.3 = ~30% trigram overlap)
|
||||
# Lower values = more fuzzy, higher values = stricter matching
|
||||
similarity_threshold = 0.3
|
||||
|
||||
# Match documents that either:
|
||||
# 1. Have high trigram similarity (fuzzy match - handles typos)
|
||||
# 2. Contain the exact substring (ILIKE - handles partial matches)
|
||||
query = query.filter(
|
||||
or_(
|
||||
func.similarity(Document.title, search_term) > similarity_threshold,
|
||||
Document.title.ilike(f"%{search_term}%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Order by similarity score (descending) for best relevance ranking
|
||||
# Higher similarity = better match = appears first
|
||||
query = query.order_by(
|
||||
desc(func.similarity(Document.title, search_term)),
|
||||
Document.title, # Alphabetical tiebreaker
|
||||
)
|
||||
|
||||
# Fetch page_size + 1 to determine has_more without COUNT query
|
||||
offset = page * page_size
|
||||
result = await session.execute(query.offset(offset).limit(page_size + 1))
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > page_size
|
||||
items = rows[:page_size] # Only return requested page_size
|
||||
|
||||
# Convert to response format
|
||||
api_documents = [
|
||||
DocumentTitleRead(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
document_type=row.document_type,
|
||||
)
|
||||
for row in items
|
||||
]
|
||||
|
||||
return DocumentTitleSearchResponse(
|
||||
items=api_documents,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to search document titles: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/type-counts")
|
||||
async def get_document_type_counts(
|
||||
search_space_id: int | None = None,
|
||||
|
|
|
|||
|
|
@ -19,13 +19,14 @@ from datetime import UTC, datetime
|
|||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
|
|
@ -508,7 +509,19 @@ async def get_thread_full(
|
|||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
return thread
|
||||
# Check if thread has any comments
|
||||
comment_count = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(ChatComment)
|
||||
.join(NewChatMessage, ChatComment.message_id == NewChatMessage.id)
|
||||
.where(NewChatMessage.thread_id == thread.id)
|
||||
)
|
||||
|
||||
return {
|
||||
**thread.__dict__,
|
||||
"messages": thread.messages,
|
||||
"has_comments": (comment_count or 0) > 0,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
102
surfsense_backend/app/routes/notifications_routes.py
Normal file
102
surfsense_backend/app/routes/notifications_routes.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Notifications API routes.
|
||||
These endpoints allow marking notifications as read.
|
||||
Electric SQL automatically syncs the changes to all connected clients.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Notification, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
class MarkReadResponse(BaseModel):
|
||||
"""Response for mark as read operations."""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class MarkAllReadResponse(BaseModel):
|
||||
"""Response for mark all as read operation."""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
updated_count: int
|
||||
|
||||
|
||||
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
|
||||
async def mark_notification_as_read(
|
||||
notification_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkReadResponse:
|
||||
"""
|
||||
Mark a single notification as read.
|
||||
|
||||
Electric SQL will automatically sync this change to all connected clients.
|
||||
"""
|
||||
# Verify the notification belongs to the user
|
||||
result = await session.execute(
|
||||
select(Notification).where(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user.id,
|
||||
)
|
||||
)
|
||||
notification = result.scalar_one_or_none()
|
||||
|
||||
if not notification:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found",
|
||||
)
|
||||
|
||||
if notification.read:
|
||||
return MarkReadResponse(
|
||||
success=True,
|
||||
message="Notification already marked as read",
|
||||
)
|
||||
|
||||
# Update the notification
|
||||
notification.read = True
|
||||
await session.commit()
|
||||
|
||||
return MarkReadResponse(
|
||||
success=True,
|
||||
message="Notification marked as read",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/read-all", response_model=MarkAllReadResponse)
|
||||
async def mark_all_notifications_as_read(
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkAllReadResponse:
|
||||
"""
|
||||
Mark all notifications as read for the current user.
|
||||
|
||||
Electric SQL will automatically sync these changes to all connected clients.
|
||||
"""
|
||||
# Update all unread notifications for the user
|
||||
result = await session.execute(
|
||||
update(Notification)
|
||||
.where(
|
||||
Notification.user_id == user.id,
|
||||
Notification.read == False, # noqa: E712
|
||||
)
|
||||
.values(read=True)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
return MarkAllReadResponse(
|
||||
success=True,
|
||||
message=f"Marked {updated_count} notification(s) as read",
|
||||
updated_count=updated_count,
|
||||
)
|
||||
|
|
@ -452,6 +452,8 @@ async def list_members(
|
|||
"created_at": membership.created_at,
|
||||
"role": membership.role,
|
||||
"user_email": member_user.email if member_user else None,
|
||||
"user_display_name": member_user.display_name if member_user else None,
|
||||
"user_avatar_url": member_user.avatar_url if member_user else None,
|
||||
}
|
||||
response.append(membership_dict)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,13 +4,15 @@ from .documents import (
|
|||
DocumentBase,
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentTitleRead,
|
||||
DocumentTitleSearchResponse,
|
||||
DocumentUpdate,
|
||||
DocumentWithChunksRead,
|
||||
ExtensionDocumentContent,
|
||||
ExtensionDocumentMetadata,
|
||||
PaginatedResponse,
|
||||
)
|
||||
from .google_drive import DriveItem, GoogleDriveIndexRequest
|
||||
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .new_chat import (
|
||||
ChatMessage,
|
||||
|
|
@ -85,6 +87,8 @@ __all__ = [
|
|||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentTitleRead",
|
||||
"DocumentTitleSearchResponse",
|
||||
"DocumentUpdate",
|
||||
"DocumentWithChunksRead",
|
||||
"DocumentsCreate",
|
||||
|
|
@ -94,6 +98,7 @@ __all__ = [
|
|||
"ExtensionDocumentMetadata",
|
||||
"GlobalNewLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# RBAC schemas
|
||||
|
|
|
|||
129
surfsense_backend/app/schemas/chat_comments.py
Normal file
129
surfsense_backend/app/schemas/chat_comments.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Pydantic schemas for chat comments and mentions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# =============================================================================
|
||||
# Request Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CommentCreateRequest(BaseModel):
|
||||
"""Schema for creating a comment or reply."""
|
||||
|
||||
content: str = Field(..., min_length=1, max_length=5000)
|
||||
|
||||
|
||||
class CommentUpdateRequest(BaseModel):
|
||||
"""Schema for updating a comment."""
|
||||
|
||||
content: str = Field(..., min_length=1, max_length=5000)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Author Schema
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AuthorResponse(BaseModel):
|
||||
"""Author information for comments."""
|
||||
|
||||
id: UUID
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
email: str
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Comment Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CommentReplyResponse(BaseModel):
|
||||
"""Schema for a comment reply (no nested replies)."""
|
||||
|
||||
id: int
|
||||
content: str
|
||||
content_rendered: str
|
||||
author: AuthorResponse | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_edited: bool
|
||||
can_edit: bool = False
|
||||
can_delete: bool = False
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CommentResponse(BaseModel):
|
||||
"""Schema for a top-level comment with replies."""
|
||||
|
||||
id: int
|
||||
message_id: int
|
||||
content: str
|
||||
content_rendered: str
|
||||
author: AuthorResponse | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_edited: bool
|
||||
can_edit: bool = False
|
||||
can_delete: bool = False
|
||||
reply_count: int
|
||||
replies: list[CommentReplyResponse] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CommentListResponse(BaseModel):
|
||||
"""Response for listing comments on a message."""
|
||||
|
||||
comments: list[CommentResponse]
|
||||
total_count: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mention Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MentionContextResponse(BaseModel):
|
||||
"""Context information for where a mention occurred."""
|
||||
|
||||
thread_id: int
|
||||
thread_title: str
|
||||
message_id: int
|
||||
search_space_id: int
|
||||
search_space_name: str
|
||||
|
||||
|
||||
class MentionCommentResponse(BaseModel):
|
||||
"""Abbreviated comment info for mention display."""
|
||||
|
||||
id: int
|
||||
content_preview: str
|
||||
author: AuthorResponse | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class MentionResponse(BaseModel):
|
||||
"""Schema for a mention notification."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
comment: MentionCommentResponse
|
||||
context: MentionContextResponse
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class MentionListResponse(BaseModel):
|
||||
"""Response for listing user's mentions."""
|
||||
|
||||
mentions: list[MentionResponse]
|
||||
total_count: int
|
||||
|
|
@ -67,3 +67,20 @@ class PaginatedResponse[T](BaseModel):
|
|||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class DocumentTitleRead(BaseModel):
|
||||
"""Lightweight document response for mention picker - only essential fields."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
document_type: DocumentType
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DocumentTitleSearchResponse(BaseModel):
|
||||
"""Response for document title search - optimized for typeahead."""
|
||||
|
||||
items: list[DocumentTitleRead]
|
||||
has_more: bool
|
||||
|
|
|
|||
|
|
@ -10,6 +10,25 @@ class DriveItem(BaseModel):
|
|||
name: str = Field(..., description="Item display name")
|
||||
|
||||
|
||||
class GoogleDriveIndexingOptions(BaseModel):
|
||||
"""Indexing options for Google Drive connector."""
|
||||
|
||||
max_files_per_folder: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=1000,
|
||||
description="Maximum number of files to index from each folder (1-1000)",
|
||||
)
|
||||
incremental_sync: bool = Field(
|
||||
default=True,
|
||||
description="Only sync changes since last index (faster). Disable for a full re-index.",
|
||||
)
|
||||
include_subfolders: bool = Field(
|
||||
default=True,
|
||||
description="Recursively index files in subfolders of selected folders",
|
||||
)
|
||||
|
||||
|
||||
class GoogleDriveIndexRequest(BaseModel):
|
||||
"""Request body for indexing Google Drive content."""
|
||||
|
||||
|
|
@ -19,6 +38,10 @@ class GoogleDriveIndexRequest(BaseModel):
|
|||
files: list[DriveItem] = Field(
|
||||
default_factory=list, description="List of specific files to index"
|
||||
)
|
||||
indexing_options: GoogleDriveIndexingOptions = Field(
|
||||
default_factory=GoogleDriveIndexingOptions,
|
||||
description="Indexing configuration options",
|
||||
)
|
||||
|
||||
def has_items(self) -> bool:
|
||||
"""Check if any items are selected."""
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ class NewChatThreadWithMessages(NewChatThreadRead):
|
|||
"""Schema for reading a thread with its messages."""
|
||||
|
||||
messages: list[NewChatMessageRead] = []
|
||||
has_comments: bool = False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -73,8 +73,10 @@ class MembershipRead(BaseModel):
|
|||
created_at: datetime
|
||||
# Nested role info
|
||||
role: RoleRead | None = None
|
||||
# User email (populated separately)
|
||||
# User details (populated separately)
|
||||
user_email: str | None = None
|
||||
user_display_name: str | None = None
|
||||
user_avatar_url: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
|
|
|||
|
|
@ -83,19 +83,34 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod
|
|||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for an MCP server connection (similar to Cursor's config)."""
|
||||
"""Configuration for an MCP server connection.
|
||||
|
||||
command: str # e.g., "uvx", "node", "python"
|
||||
Supports two transport types:
|
||||
- stdio: Local process (command, args, env)
|
||||
- streamable-http/http/sse: Remote HTTP server (url, headers)
|
||||
"""
|
||||
|
||||
# stdio transport fields
|
||||
command: str | None = None # e.g., "uvx", "node", "python"
|
||||
args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"]
|
||||
env: dict[str, str] = {} # Environment variables for the server process
|
||||
transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common)
|
||||
|
||||
# HTTP transport fields
|
||||
url: str | None = None # e.g., "https://mcp-server.com/mcp"
|
||||
headers: dict[str, str] = {} # HTTP headers for authentication
|
||||
|
||||
transport: str = "stdio" # "stdio" | "streamable-http" | "http" | "sse"
|
||||
|
||||
def is_http_transport(self) -> bool:
|
||||
"""Check if this config uses HTTP transport."""
|
||||
return self.transport in ("streamable-http", "http", "sse")
|
||||
|
||||
|
||||
class MCPConnectorCreate(BaseModel):
|
||||
"""Schema for creating an MCP connector."""
|
||||
|
||||
name: str
|
||||
server_config: MCPServerConfig
|
||||
server_config: MCPServerConfig # Single MCP server configuration
|
||||
|
||||
|
||||
class MCPConnectorUpdate(BaseModel):
|
||||
|
|
@ -106,7 +121,7 @@ class MCPConnectorUpdate(BaseModel):
|
|||
|
||||
|
||||
class MCPConnectorRead(BaseModel):
|
||||
"""Schema for reading an MCP connector with server config."""
|
||||
"""Schema for reading an MCP connector with server configs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
|
@ -123,7 +138,8 @@ class MCPConnectorRead(BaseModel):
|
|||
def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
|
||||
"""Convert from base SearchSourceConnectorRead."""
|
||||
config = connector.config or {}
|
||||
server_config = MCPServerConfig(**config.get("server_config", {}))
|
||||
server_config_data = config.get("server_config", {})
|
||||
server_config = MCPServerConfig(**server_config_data)
|
||||
|
||||
return cls(
|
||||
id=connector.id,
|
||||
|
|
|
|||
733
surfsense_backend/app/services/chat_comments_service.py
Normal file
733
surfsense_backend/app/services/chat_comments_service.py
Normal file
|
|
@ -0,0 +1,733 @@
|
|||
"""
|
||||
Service layer for chat comments and mentions.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatCommentMention,
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
has_permission,
|
||||
)
|
||||
from app.schemas.chat_comments import (
|
||||
AuthorResponse,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
CommentResponse,
|
||||
MentionCommentResponse,
|
||||
MentionContextResponse,
|
||||
MentionListResponse,
|
||||
MentionResponse,
|
||||
)
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.utils.chat_comments import parse_mentions, render_mentions
|
||||
from app.utils.rbac import check_permission, get_user_permissions
|
||||
|
||||
|
||||
async def get_user_names_for_mentions(
|
||||
session: AsyncSession,
|
||||
user_ids: set[UUID],
|
||||
) -> dict[UUID, str]:
|
||||
"""
|
||||
Fetch display names for a set of user IDs.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_ids: Set of user UUIDs to look up
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user UUID to display name
|
||||
"""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
result = await session.execute(
|
||||
select(User.id, User.display_name).filter(User.id.in_(user_ids))
|
||||
)
|
||||
return {row.id: row.display_name or "Unknown" for row in result.all()}
|
||||
|
||||
|
||||
async def process_mentions(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
search_space_id: int,
|
||||
) -> dict[UUID, int]:
|
||||
"""
|
||||
Parse mentions from content, validate users are members, and insert mention records.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
comment_id: ID of the comment containing mentions
|
||||
content: Comment text with @[uuid] mentions
|
||||
search_space_id: ID of the search space for membership validation
|
||||
|
||||
Returns:
|
||||
Dictionary mapping mentioned user UUID to their mention record ID
|
||||
"""
|
||||
mentioned_uuids = parse_mentions(content)
|
||||
if not mentioned_uuids:
|
||||
return {}
|
||||
|
||||
# Get valid members from the mentioned UUIDs
|
||||
result = await session.execute(
|
||||
select(SearchSpaceMembership.user_id).filter(
|
||||
SearchSpaceMembership.search_space_id == search_space_id,
|
||||
SearchSpaceMembership.user_id.in_(mentioned_uuids),
|
||||
)
|
||||
)
|
||||
valid_member_ids = result.scalars().all()
|
||||
|
||||
# Insert mention records for valid members and collect their IDs
|
||||
mentions_map: dict[UUID, int] = {}
|
||||
for user_id in valid_member_ids:
|
||||
mention = ChatCommentMention(
|
||||
comment_id=comment_id,
|
||||
mentioned_user_id=user_id,
|
||||
)
|
||||
session.add(mention)
|
||||
await session.flush()
|
||||
mentions_map[user_id] = mention.id
|
||||
|
||||
return mentions_map
|
||||
|
||||
|
||||
async def get_comments_for_message(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
user: User,
|
||||
) -> CommentListResponse:
|
||||
"""
|
||||
Get all comments for a message with their replies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
message_id: ID of the message to get comments for
|
||||
user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
CommentListResponse with all top-level comments and their replies
|
||||
|
||||
Raises:
|
||||
HTTPException: If message not found or user lacks COMMENTS_READ permission
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.thread))
|
||||
.filter(NewChatMessage.id == message_id)
|
||||
)
|
||||
message = result.scalars().first()
|
||||
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
search_space_id = message.thread.search_space_id
|
||||
|
||||
# Check permission to read comments
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
)
|
||||
|
||||
# Get user permissions for can_delete computation
|
||||
user_permissions = await get_user_permissions(session, user.id, search_space_id)
|
||||
can_delete_any = has_permission(user_permissions, Permission.COMMENTS_DELETE.value)
|
||||
|
||||
# Get top-level comments (parent_id IS NULL) with their authors and replies
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(
|
||||
selectinload(ChatComment.author),
|
||||
selectinload(ChatComment.replies).selectinload(ChatComment.author),
|
||||
)
|
||||
.filter(
|
||||
ChatComment.message_id == message_id,
|
||||
ChatComment.parent_id.is_(None),
|
||||
)
|
||||
.order_by(ChatComment.created_at)
|
||||
)
|
||||
top_level_comments = result.scalars().all()
|
||||
|
||||
# Collect all mentioned UUIDs from comments and replies for rendering
|
||||
all_mentioned_uuids: set[UUID] = set()
|
||||
for comment in top_level_comments:
|
||||
all_mentioned_uuids.update(parse_mentions(comment.content))
|
||||
for reply in comment.replies:
|
||||
all_mentioned_uuids.update(parse_mentions(reply.content))
|
||||
|
||||
# Fetch display names for mentioned users
|
||||
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
|
||||
|
||||
comments = []
|
||||
for comment in top_level_comments:
|
||||
author = None
|
||||
if comment.author:
|
||||
author = AuthorResponse(
|
||||
id=comment.author.id,
|
||||
display_name=comment.author.display_name,
|
||||
avatar_url=comment.author.avatar_url,
|
||||
email=comment.author.email,
|
||||
)
|
||||
|
||||
replies = []
|
||||
for reply in sorted(comment.replies, key=lambda r: r.created_at):
|
||||
reply_author = None
|
||||
if reply.author:
|
||||
reply_author = AuthorResponse(
|
||||
id=reply.author.id,
|
||||
display_name=reply.author.display_name,
|
||||
avatar_url=reply.author.avatar_url,
|
||||
email=reply.author.email,
|
||||
)
|
||||
|
||||
is_reply_author = reply.author_id == user.id if reply.author_id else False
|
||||
replies.append(
|
||||
CommentReplyResponse(
|
||||
id=reply.id,
|
||||
content=reply.content,
|
||||
content_rendered=render_mentions(reply.content, user_names),
|
||||
author=reply_author,
|
||||
created_at=reply.created_at,
|
||||
updated_at=reply.updated_at,
|
||||
is_edited=reply.updated_at > reply.created_at,
|
||||
can_edit=is_reply_author,
|
||||
can_delete=is_reply_author or can_delete_any,
|
||||
)
|
||||
)
|
||||
|
||||
is_comment_author = comment.author_id == user.id if comment.author_id else False
|
||||
comments.append(
|
||||
CommentResponse(
|
||||
id=comment.id,
|
||||
message_id=comment.message_id,
|
||||
content=comment.content,
|
||||
content_rendered=render_mentions(comment.content, user_names),
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
updated_at=comment.updated_at,
|
||||
is_edited=comment.updated_at > comment.created_at,
|
||||
can_edit=is_comment_author,
|
||||
can_delete=is_comment_author or can_delete_any,
|
||||
reply_count=len(replies),
|
||||
replies=replies,
|
||||
)
|
||||
)
|
||||
|
||||
return CommentListResponse(
|
||||
comments=comments,
|
||||
total_count=len(comments),
|
||||
)
|
||||
|
||||
|
||||
async def create_comment(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
) -> CommentResponse:
|
||||
"""
|
||||
Create a top-level comment on an AI response.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
message_id: ID of the message to comment on
|
||||
content: Comment text content
|
||||
user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
CommentResponse for the created comment
|
||||
|
||||
Raises:
|
||||
HTTPException: If message not found, not AI response, or user lacks COMMENTS_CREATE permission
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.thread))
|
||||
.filter(NewChatMessage.id == message_id)
|
||||
)
|
||||
message = result.scalars().first()
|
||||
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Validate message is an AI response
|
||||
if message.role != NewChatMessageRole.ASSISTANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Comments can only be added to AI responses",
|
||||
)
|
||||
|
||||
search_space_id = message.thread.search_space_id
|
||||
|
||||
# Check permission to create comments
|
||||
user_permissions = await get_user_permissions(session, user.id, search_space_id)
|
||||
if not has_permission(user_permissions, Permission.COMMENTS_CREATE.value):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have permission to create comments in this search space",
|
||||
)
|
||||
|
||||
comment = ChatComment(
|
||||
message_id=message_id,
|
||||
author_id=user.id,
|
||||
content=content,
|
||||
)
|
||||
session.add(comment)
|
||||
await session.flush()
|
||||
|
||||
# Process mentions - returns map of user_id -> mention_id
|
||||
mentions_map = await process_mentions(session, comment.id, content, search_space_id)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(comment)
|
||||
|
||||
# Fetch user names for rendering mentions (reuse mentions_map keys)
|
||||
user_names = await get_user_names_for_mentions(session, set(mentions_map.keys()))
|
||||
|
||||
# Create notifications for mentioned users (excluding author)
|
||||
thread = message.thread
|
||||
author_name = user.display_name or user.email
|
||||
content_preview = render_mentions(content, user_names)
|
||||
for mentioned_user_id, mention_id in mentions_map.items():
|
||||
if mentioned_user_id == user.id:
|
||||
continue # Don't notify yourself
|
||||
await NotificationService.mention.notify_new_mention(
|
||||
session=session,
|
||||
mentioned_user_id=mentioned_user_id,
|
||||
mention_id=mention_id,
|
||||
comment_id=comment.id,
|
||||
message_id=message_id,
|
||||
thread_id=thread.id,
|
||||
thread_title=thread.title or "Untitled thread",
|
||||
author_id=str(user.id),
|
||||
author_name=author_name,
|
||||
content_preview=content_preview[:200],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
author = AuthorResponse(
|
||||
id=user.id,
|
||||
display_name=user.display_name,
|
||||
avatar_url=user.avatar_url,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
return CommentResponse(
|
||||
id=comment.id,
|
||||
message_id=comment.message_id,
|
||||
content=comment.content,
|
||||
content_rendered=render_mentions(content, user_names),
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
updated_at=comment.updated_at,
|
||||
is_edited=False,
|
||||
can_edit=True,
|
||||
can_delete=True,
|
||||
reply_count=0,
|
||||
replies=[],
|
||||
)
|
||||
|
||||
|
||||
async def create_reply(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
) -> CommentReplyResponse:
|
||||
"""
|
||||
Create a reply to an existing comment.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
comment_id: ID of the parent comment to reply to
|
||||
content: Reply text content
|
||||
user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
CommentReplyResponse for the created reply
|
||||
|
||||
Raises:
|
||||
HTTPException: If comment not found, is already a reply, or user lacks COMMENTS_CREATE permission
|
||||
"""
|
||||
# Get parent comment with its message and thread
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(selectinload(ChatComment.message).selectinload(NewChatMessage.thread))
|
||||
.filter(ChatComment.id == comment_id)
|
||||
)
|
||||
parent_comment = result.scalars().first()
|
||||
|
||||
if not parent_comment:
|
||||
raise HTTPException(status_code=404, detail="Comment not found")
|
||||
|
||||
# Validate parent is a top-level comment (cannot reply to a reply)
|
||||
if parent_comment.parent_id is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot reply to a reply",
|
||||
)
|
||||
|
||||
search_space_id = parent_comment.message.thread.search_space_id
|
||||
|
||||
# Check permission to create comments
|
||||
user_permissions = await get_user_permissions(session, user.id, search_space_id)
|
||||
if not has_permission(user_permissions, Permission.COMMENTS_CREATE.value):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have permission to create comments in this search space",
|
||||
)
|
||||
|
||||
reply = ChatComment(
|
||||
message_id=parent_comment.message_id,
|
||||
parent_id=comment_id,
|
||||
author_id=user.id,
|
||||
content=content,
|
||||
)
|
||||
session.add(reply)
|
||||
await session.flush()
|
||||
|
||||
# Process mentions - returns map of user_id -> mention_id
|
||||
mentions_map = await process_mentions(session, reply.id, content, search_space_id)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(reply)
|
||||
|
||||
# Fetch user names for rendering mentions (reuse mentions_map keys)
|
||||
user_names = await get_user_names_for_mentions(session, set(mentions_map.keys()))
|
||||
|
||||
# Create notifications for mentioned users (excluding author)
|
||||
thread = parent_comment.message.thread
|
||||
author_name = user.display_name or user.email
|
||||
content_preview = render_mentions(content, user_names)
|
||||
for mentioned_user_id, mention_id in mentions_map.items():
|
||||
if mentioned_user_id == user.id:
|
||||
continue # Don't notify yourself
|
||||
await NotificationService.mention.notify_new_mention(
|
||||
session=session,
|
||||
mentioned_user_id=mentioned_user_id,
|
||||
mention_id=mention_id,
|
||||
comment_id=reply.id,
|
||||
message_id=parent_comment.message_id,
|
||||
thread_id=thread.id,
|
||||
thread_title=thread.title or "Untitled thread",
|
||||
author_id=str(user.id),
|
||||
author_name=author_name,
|
||||
content_preview=content_preview[:200],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
author = AuthorResponse(
|
||||
id=user.id,
|
||||
display_name=user.display_name,
|
||||
avatar_url=user.avatar_url,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
return CommentReplyResponse(
|
||||
id=reply.id,
|
||||
content=reply.content,
|
||||
content_rendered=render_mentions(content, user_names),
|
||||
author=author,
|
||||
created_at=reply.created_at,
|
||||
updated_at=reply.updated_at,
|
||||
is_edited=False,
|
||||
can_edit=True,
|
||||
can_delete=True,
|
||||
)
|
||||
|
||||
|
||||
async def update_comment(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
) -> CommentReplyResponse:
|
||||
"""
|
||||
Update a comment's content (author only).
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
comment_id: ID of the comment to update
|
||||
content: New comment text content
|
||||
user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
CommentReplyResponse for the updated comment
|
||||
|
||||
Raises:
|
||||
HTTPException: If comment not found or user is not the author
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(
|
||||
selectinload(ChatComment.author),
|
||||
selectinload(ChatComment.message).selectinload(NewChatMessage.thread),
|
||||
)
|
||||
.filter(ChatComment.id == comment_id)
|
||||
)
|
||||
comment = result.scalars().first()
|
||||
|
||||
if not comment:
|
||||
raise HTTPException(status_code=404, detail="Comment not found")
|
||||
|
||||
if comment.author_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You can only edit your own comments",
|
||||
)
|
||||
|
||||
search_space_id = comment.message.thread.search_space_id
|
||||
|
||||
# Get existing mentioned user IDs
|
||||
existing_result = await session.execute(
|
||||
select(ChatCommentMention.mentioned_user_id).filter(
|
||||
ChatCommentMention.comment_id == comment_id
|
||||
)
|
||||
)
|
||||
existing_mention_ids = set(existing_result.scalars().all())
|
||||
|
||||
# Parse new mentions from updated content
|
||||
new_mention_uuids = set(parse_mentions(content))
|
||||
|
||||
# Validate new mentions are search space members
|
||||
if new_mention_uuids:
|
||||
valid_result = await session.execute(
|
||||
select(SearchSpaceMembership.user_id).filter(
|
||||
SearchSpaceMembership.search_space_id == search_space_id,
|
||||
SearchSpaceMembership.user_id.in_(new_mention_uuids),
|
||||
)
|
||||
)
|
||||
valid_new_mentions = set(valid_result.scalars().all())
|
||||
else:
|
||||
valid_new_mentions = set()
|
||||
|
||||
# Compute diff: removed, kept (preserve read status), added
|
||||
mentions_to_remove = existing_mention_ids - valid_new_mentions
|
||||
mentions_to_add = valid_new_mentions - existing_mention_ids
|
||||
|
||||
# Delete removed mentions
|
||||
if mentions_to_remove:
|
||||
await session.execute(
|
||||
delete(ChatCommentMention).where(
|
||||
ChatCommentMention.comment_id == comment_id,
|
||||
ChatCommentMention.mentioned_user_id.in_(mentions_to_remove),
|
||||
)
|
||||
)
|
||||
|
||||
# Add new mentions and collect their IDs for notifications
|
||||
new_mentions_map: dict[UUID, int] = {}
|
||||
for user_id in mentions_to_add:
|
||||
mention = ChatCommentMention(
|
||||
comment_id=comment_id,
|
||||
mentioned_user_id=user_id,
|
||||
)
|
||||
session.add(mention)
|
||||
await session.flush()
|
||||
new_mentions_map[user_id] = mention.id
|
||||
|
||||
comment.content = content
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(comment)
|
||||
|
||||
# Fetch user names for rendering mentions
|
||||
user_names = await get_user_names_for_mentions(session, valid_new_mentions)
|
||||
|
||||
# Create notifications for newly added mentions (excluding author)
|
||||
if new_mentions_map:
|
||||
thread = comment.message.thread
|
||||
author_name = user.display_name or user.email
|
||||
content_preview = render_mentions(content, user_names)
|
||||
for mentioned_user_id, mention_id in new_mentions_map.items():
|
||||
if mentioned_user_id == user.id:
|
||||
continue # Don't notify yourself
|
||||
await NotificationService.mention.notify_new_mention(
|
||||
session=session,
|
||||
mentioned_user_id=mentioned_user_id,
|
||||
mention_id=mention_id,
|
||||
comment_id=comment_id,
|
||||
message_id=comment.message_id,
|
||||
thread_id=thread.id,
|
||||
thread_title=thread.title or "Untitled thread",
|
||||
author_id=str(user.id),
|
||||
author_name=author_name,
|
||||
content_preview=content_preview[:200],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
author = AuthorResponse(
|
||||
id=user.id,
|
||||
display_name=user.display_name,
|
||||
avatar_url=user.avatar_url,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
return CommentReplyResponse(
|
||||
id=comment.id,
|
||||
content=comment.content,
|
||||
content_rendered=render_mentions(content, user_names),
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
updated_at=comment.updated_at,
|
||||
is_edited=comment.updated_at > comment.created_at,
|
||||
can_edit=True,
|
||||
can_delete=True,
|
||||
)
|
||||
|
||||
|
||||
async def delete_comment(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
user: User,
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a comment (author or user with COMMENTS_DELETE permission).
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
comment_id: ID of the comment to delete
|
||||
user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
Dict with deletion confirmation
|
||||
|
||||
Raises:
|
||||
HTTPException: If comment not found or user lacks permission to delete
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(selectinload(ChatComment.message).selectinload(NewChatMessage.thread))
|
||||
.filter(ChatComment.id == comment_id)
|
||||
)
|
||||
comment = result.scalars().first()
|
||||
|
||||
if not comment:
|
||||
raise HTTPException(status_code=404, detail="Comment not found")
|
||||
|
||||
is_author = comment.author_id == user.id
|
||||
|
||||
# Check if user has COMMENTS_DELETE permission
|
||||
search_space_id = comment.message.thread.search_space_id
|
||||
user_permissions = await get_user_permissions(session, user.id, search_space_id)
|
||||
can_delete_any = has_permission(user_permissions, Permission.COMMENTS_DELETE.value)
|
||||
|
||||
if not is_author and not can_delete_any:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You do not have permission to delete this comment",
|
||||
)
|
||||
|
||||
await session.delete(comment)
|
||||
await session.commit()
|
||||
|
||||
return {"message": "Comment deleted successfully", "comment_id": comment_id}
|
||||
|
||||
|
||||
async def get_user_mentions(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
search_space_id: int | None = None,
|
||||
) -> MentionListResponse:
|
||||
"""
|
||||
Get mentions for the current user, optionally filtered by search space.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user: The current authenticated user
|
||||
search_space_id: Optional search space ID to filter mentions
|
||||
|
||||
Returns:
|
||||
MentionListResponse with mentions and total count
|
||||
"""
|
||||
# Build query with joins for filtering by search_space_id
|
||||
query = (
|
||||
select(ChatCommentMention)
|
||||
.join(ChatComment, ChatCommentMention.comment_id == ChatComment.id)
|
||||
.join(NewChatMessage, ChatComment.message_id == NewChatMessage.id)
|
||||
.join(NewChatThread, NewChatMessage.thread_id == NewChatThread.id)
|
||||
.options(
|
||||
selectinload(ChatCommentMention.comment).selectinload(ChatComment.author),
|
||||
selectinload(ChatCommentMention.comment).selectinload(ChatComment.message),
|
||||
)
|
||||
.filter(ChatCommentMention.mentioned_user_id == user.id)
|
||||
.order_by(ChatCommentMention.created_at.desc())
|
||||
)
|
||||
|
||||
if search_space_id is not None:
|
||||
query = query.filter(NewChatThread.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
mention_records = result.scalars().all()
|
||||
|
||||
# Fetch search space info for context (single query for all unique search spaces)
|
||||
thread_ids = {m.comment.message.thread_id for m in mention_records}
|
||||
if thread_ids:
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.search_space))
|
||||
.filter(NewChatThread.id.in_(thread_ids))
|
||||
)
|
||||
threads_map = {t.id: t for t in thread_result.scalars().all()}
|
||||
else:
|
||||
threads_map = {}
|
||||
|
||||
mentions = []
|
||||
for mention in mention_records:
|
||||
comment = mention.comment
|
||||
message = comment.message
|
||||
thread = threads_map.get(message.thread_id)
|
||||
search_space = thread.search_space if thread else None
|
||||
|
||||
author = None
|
||||
if comment.author:
|
||||
author = AuthorResponse(
|
||||
id=comment.author.id,
|
||||
display_name=comment.author.display_name,
|
||||
avatar_url=comment.author.avatar_url,
|
||||
email=comment.author.email,
|
||||
)
|
||||
|
||||
content_preview = (
|
||||
comment.content[:100] + "..."
|
||||
if len(comment.content) > 100
|
||||
else comment.content
|
||||
)
|
||||
|
||||
mentions.append(
|
||||
MentionResponse(
|
||||
id=mention.id,
|
||||
created_at=mention.created_at,
|
||||
comment=MentionCommentResponse(
|
||||
id=comment.id,
|
||||
content_preview=content_preview,
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
),
|
||||
context=MentionContextResponse(
|
||||
thread_id=thread.id if thread else 0,
|
||||
thread_title=thread.title or "Untitled" if thread else "Unknown",
|
||||
message_id=message.id,
|
||||
search_space_id=search_space.id if search_space else 0,
|
||||
search_space_name=search_space.name if search_space else "Unknown",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return MentionListResponse(
|
||||
mentions=mentions,
|
||||
total_count=len(mentions),
|
||||
)
|
||||
735
surfsense_backend/app/services/notification_service.py
Normal file
735
surfsense_backend/app/services/notification_service.py
Normal file
|
|
@ -0,0 +1,735 @@
|
|||
"""Service for creating and managing notifications with Electric SQL sync."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import Notification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseNotificationHandler:
|
||||
"""Base class for notification handlers - provides common functionality."""
|
||||
|
||||
def __init__(self, notification_type: str):
|
||||
"""
|
||||
Initialize the notification handler.
|
||||
|
||||
Args:
|
||||
notification_type: Type of notification (e.g., 'connector_indexing', 'document_processing')
|
||||
"""
|
||||
self.notification_type = notification_type
|
||||
|
||||
async def find_notification_by_operation(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
operation_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> Notification | None:
|
||||
"""
|
||||
Find an existing notification by operation ID.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
operation_id: Unique operation identifier
|
||||
search_space_id: Optional search space ID
|
||||
|
||||
Returns:
|
||||
Notification if found, None otherwise
|
||||
"""
|
||||
query = select(Notification).where(
|
||||
Notification.user_id == user_id,
|
||||
Notification.type == self.notification_type,
|
||||
Notification.notification_metadata["operation_id"].astext == operation_id,
|
||||
)
|
||||
if search_space_id is not None:
|
||||
query = query.where(Notification.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def find_or_create_notification(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
operation_id: str,
|
||||
title: str,
|
||||
message: str,
|
||||
search_space_id: int | None = None,
|
||||
initial_metadata: dict[str, Any] | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Find an existing notification or create a new one.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
operation_id: Unique operation identifier
|
||||
title: Notification title
|
||||
message: Notification message
|
||||
search_space_id: Optional search space ID
|
||||
initial_metadata: Initial metadata dictionary
|
||||
|
||||
Returns:
|
||||
Notification: The found or created notification
|
||||
"""
|
||||
# Try to find existing notification
|
||||
notification = await self.find_notification_by_operation(
|
||||
session, user_id, operation_id, search_space_id
|
||||
)
|
||||
|
||||
if notification:
|
||||
# Update existing notification
|
||||
notification.title = title
|
||||
notification.message = message
|
||||
if initial_metadata:
|
||||
notification.notification_metadata = {
|
||||
**notification.notification_metadata,
|
||||
**initial_metadata,
|
||||
}
|
||||
# Mark JSONB column as modified so SQLAlchemy detects the change
|
||||
flag_modified(notification, "notification_metadata")
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(
|
||||
f"Updated notification {notification.id} for operation {operation_id}"
|
||||
)
|
||||
return notification
|
||||
|
||||
# Create new notification
|
||||
metadata = initial_metadata or {}
|
||||
metadata["operation_id"] = operation_id
|
||||
metadata["status"] = "in_progress"
|
||||
metadata["started_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
type=self.notification_type,
|
||||
title=title,
|
||||
message=message,
|
||||
notification_metadata=metadata,
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(
|
||||
f"Created notification {notification.id} for operation {operation_id}"
|
||||
)
|
||||
return notification
|
||||
|
||||
async def update_notification(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
notification: Notification,
|
||||
title: str | None = None,
|
||||
message: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata_updates: dict[str, Any] | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update an existing notification.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
title: New title (optional)
|
||||
message: New message (optional)
|
||||
status: New status (optional)
|
||||
metadata_updates: Additional metadata to merge (optional)
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
"""
|
||||
if title is not None:
|
||||
notification.title = title
|
||||
if message is not None:
|
||||
notification.message = message
|
||||
|
||||
if status is not None:
|
||||
notification.notification_metadata["status"] = status
|
||||
if status in ("completed", "failed"):
|
||||
notification.notification_metadata["completed_at"] = datetime.now(
|
||||
UTC
|
||||
).isoformat()
|
||||
# Mark JSONB column as modified so SQLAlchemy detects the change
|
||||
flag_modified(notification, "notification_metadata")
|
||||
|
||||
if metadata_updates:
|
||||
notification.notification_metadata = {
|
||||
**notification.notification_metadata,
|
||||
**metadata_updates,
|
||||
}
|
||||
# Mark JSONB column as modified
|
||||
flag_modified(notification, "notification_metadata")
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(f"Updated notification {notification.id}")
|
||||
return notification
|
||||
|
||||
|
||||
class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
||||
"""Handler for connector indexing notifications."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("connector_indexing")
|
||||
|
||||
def _generate_operation_id(
|
||||
self,
|
||||
connector_id: int,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique operation ID for a connector indexing operation.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID
|
||||
start_date: Start date (optional)
|
||||
end_date: End date (optional)
|
||||
|
||||
Returns:
|
||||
Unique operation ID string
|
||||
"""
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||
date_range = ""
|
||||
if start_date or end_date:
|
||||
date_range = f"_{start_date or 'none'}_{end_date or 'none'}"
|
||||
return f"connector_{connector_id}_{timestamp}{date_range}"
|
||||
|
||||
def _generate_google_drive_operation_id(
|
||||
self, connector_id: int, folder_count: int, file_count: int
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique operation ID for a Google Drive indexing operation.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID
|
||||
folder_count: Number of folders to index
|
||||
file_count: Number of files to index
|
||||
|
||||
Returns:
|
||||
Unique operation ID string
|
||||
"""
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||
items_info = f"_{folder_count}f_{file_count}files"
|
||||
return f"drive_{connector_id}_{timestamp}{items_info}"
|
||||
|
||||
async def notify_indexing_started(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
connector_type: str,
|
||||
search_space_id: int,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Create or update notification when connector indexing starts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
connector_id: Connector ID
|
||||
connector_name: Connector name
|
||||
connector_type: Connector type
|
||||
search_space_id: Search space ID
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
|
||||
Returns:
|
||||
Notification: The created or updated notification
|
||||
"""
|
||||
operation_id = self._generate_operation_id(connector_id, start_date, end_date)
|
||||
title = f"Syncing: {connector_name}"
|
||||
message = "Connecting to your account"
|
||||
|
||||
metadata = {
|
||||
"connector_id": connector_id,
|
||||
"connector_name": connector_name,
|
||||
"connector_type": connector_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"indexed_count": 0,
|
||||
"sync_stage": "connecting",
|
||||
}
|
||||
|
||||
return await self.find_or_create_notification(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
title=title,
|
||||
message=message,
|
||||
search_space_id=search_space_id,
|
||||
initial_metadata=metadata,
|
||||
)
|
||||
|
||||
async def notify_indexing_progress(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
notification: Notification,
|
||||
indexed_count: int,
|
||||
total_count: int | None = None,
|
||||
stage: str | None = None,
|
||||
stage_message: str | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification with indexing progress.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
indexed_count: Number of items indexed so far
|
||||
total_count: Total number of items (optional)
|
||||
stage: Current sync stage (fetching, processing, storing) (optional)
|
||||
stage_message: Optional custom message for the stage
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
"""
|
||||
# User-friendly stage messages (clean, no ellipsis - spinner shows activity)
|
||||
stage_messages = {
|
||||
"connecting": "Connecting to your account",
|
||||
"fetching": "Fetching your content",
|
||||
"processing": "Preparing for search",
|
||||
"storing": "Almost done",
|
||||
}
|
||||
|
||||
# Use stage-based message if stage provided, otherwise fallback
|
||||
if stage or stage_message:
|
||||
progress_msg = stage_message or stage_messages.get(stage, "Processing")
|
||||
else:
|
||||
# Fallback for backward compatibility
|
||||
progress_msg = "Fetching your content"
|
||||
|
||||
metadata_updates = {"indexed_count": indexed_count}
|
||||
if total_count is not None:
|
||||
metadata_updates["total_count"] = total_count
|
||||
progress_percent = int((indexed_count / total_count) * 100)
|
||||
metadata_updates["progress_percent"] = progress_percent
|
||||
if stage:
|
||||
metadata_updates["sync_stage"] = stage
|
||||
|
||||
return await self.update_notification(
|
||||
session=session,
|
||||
notification=notification,
|
||||
message=progress_msg,
|
||||
status="in_progress",
|
||||
metadata_updates=metadata_updates,
|
||||
)
|
||||
|
||||
async def notify_indexing_completed(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
notification: Notification,
|
||||
indexed_count: int,
|
||||
error_message: str | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification when connector indexing completes.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
indexed_count: Total number of items indexed
|
||||
error_message: Error message if indexing failed (optional)
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
"""
|
||||
connector_name = notification.notification_metadata.get(
|
||||
"connector_name", "Connector"
|
||||
)
|
||||
|
||||
if error_message:
|
||||
title = f"Failed: {connector_name}"
|
||||
message = f"Sync failed: {error_message}"
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {connector_name}"
|
||||
if indexed_count == 0:
|
||||
message = "Already up to date! No new items to sync."
|
||||
else:
|
||||
item_text = "item" if indexed_count == 1 else "items"
|
||||
message = f"Now searchable! {indexed_count} {item_text} synced."
|
||||
status = "completed"
|
||||
|
||||
metadata_updates = {
|
||||
"indexed_count": indexed_count,
|
||||
"sync_stage": "completed" if not error_message else "failed",
|
||||
"error_message": error_message,
|
||||
}
|
||||
|
||||
return await self.update_notification(
|
||||
session=session,
|
||||
notification=notification,
|
||||
title=title,
|
||||
message=message,
|
||||
status=status,
|
||||
metadata_updates=metadata_updates,
|
||||
)
|
||||
|
||||
async def notify_google_drive_indexing_started(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
connector_type: str,
|
||||
search_space_id: int,
|
||||
folder_count: int,
|
||||
file_count: int,
|
||||
folder_names: list[str] | None = None,
|
||||
file_names: list[str] | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Create or update notification when Google Drive indexing starts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
connector_id: Connector ID
|
||||
connector_name: Connector name
|
||||
connector_type: Connector type
|
||||
search_space_id: Search space ID
|
||||
folder_count: Number of folders to index
|
||||
file_count: Number of files to index
|
||||
folder_names: List of folder names (optional)
|
||||
file_names: List of file names (optional)
|
||||
|
||||
Returns:
|
||||
Notification: The created or updated notification
|
||||
"""
|
||||
operation_id = self._generate_google_drive_operation_id(
|
||||
connector_id, folder_count, file_count
|
||||
)
|
||||
title = f"Syncing: {connector_name}"
|
||||
message = "Preparing your files"
|
||||
|
||||
metadata = {
|
||||
"connector_id": connector_id,
|
||||
"connector_name": connector_name,
|
||||
"connector_type": connector_type,
|
||||
"folder_count": folder_count,
|
||||
"file_count": file_count,
|
||||
"indexed_count": 0,
|
||||
"sync_stage": "connecting",
|
||||
}
|
||||
|
||||
if folder_names:
|
||||
metadata["folder_names"] = folder_names
|
||||
if file_names:
|
||||
metadata["file_names"] = file_names
|
||||
|
||||
return await self.find_or_create_notification(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
title=title,
|
||||
message=message,
|
||||
search_space_id=search_space_id,
|
||||
initial_metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class DocumentProcessingNotificationHandler(BaseNotificationHandler):
|
||||
"""Handler for document processing notifications."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("document_processing")
|
||||
|
||||
def _generate_operation_id(
|
||||
self, document_type: str, filename: str, search_space_id: int
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique operation ID for a document processing operation.
|
||||
|
||||
Args:
|
||||
document_type: Type of document (FILE, YOUTUBE_VIDEO, CRAWLED_URL, etc.)
|
||||
filename: Name of the file/document
|
||||
search_space_id: Search space ID
|
||||
|
||||
Returns:
|
||||
Unique operation ID string
|
||||
"""
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f")
|
||||
# Create a short hash of filename to ensure uniqueness
|
||||
import hashlib
|
||||
|
||||
filename_hash = hashlib.md5(filename.encode()).hexdigest()[:8]
|
||||
return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}"
|
||||
|
||||
async def notify_processing_started(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
document_type: str,
|
||||
document_name: str,
|
||||
search_space_id: int,
|
||||
file_size: int | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Create notification when document processing starts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
document_type: Type of document (FILE, YOUTUBE_VIDEO, CRAWLED_URL, etc.)
|
||||
document_name: Name/title of the document
|
||||
search_space_id: Search space ID
|
||||
file_size: Size of file in bytes (optional)
|
||||
|
||||
Returns:
|
||||
Notification: The created notification
|
||||
"""
|
||||
operation_id = self._generate_operation_id(
|
||||
document_type, document_name, search_space_id
|
||||
)
|
||||
title = f"Processing: {document_name}"
|
||||
message = "Waiting in queue"
|
||||
|
||||
metadata = {
|
||||
"document_type": document_type,
|
||||
"document_name": document_name,
|
||||
"processing_stage": "queued",
|
||||
}
|
||||
|
||||
if file_size is not None:
|
||||
metadata["file_size"] = file_size
|
||||
|
||||
return await self.find_or_create_notification(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
title=title,
|
||||
message=message,
|
||||
search_space_id=search_space_id,
|
||||
initial_metadata=metadata,
|
||||
)
|
||||
|
||||
async def notify_processing_progress(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
notification: Notification,
|
||||
stage: str,
|
||||
stage_message: str | None = None,
|
||||
chunks_count: int | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification with processing progress.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
stage: Current processing stage (parsing, chunking, embedding, storing)
|
||||
stage_message: Optional custom message for the stage
|
||||
chunks_count: Number of chunks created (optional, stored in metadata only)
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
"""
|
||||
# User-friendly stage messages
|
||||
stage_messages = {
|
||||
"parsing": "Reading your file",
|
||||
"chunking": "Preparing for search",
|
||||
"embedding": "Preparing for search",
|
||||
"storing": "Finalizing",
|
||||
}
|
||||
|
||||
message = stage_message or stage_messages.get(stage, "Processing")
|
||||
|
||||
metadata_updates = {"processing_stage": stage}
|
||||
# Store chunks_count in metadata for debugging, but don't show to user
|
||||
if chunks_count is not None:
|
||||
metadata_updates["chunks_count"] = chunks_count
|
||||
|
||||
return await self.update_notification(
|
||||
session=session,
|
||||
notification=notification,
|
||||
message=message,
|
||||
status="in_progress",
|
||||
metadata_updates=metadata_updates,
|
||||
)
|
||||
|
||||
async def notify_processing_completed(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
notification: Notification,
|
||||
document_id: int | None = None,
|
||||
chunks_count: int | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification when document processing completes.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
document_id: ID of the created document (optional)
|
||||
chunks_count: Total number of chunks created (optional)
|
||||
error_message: Error message if processing failed (optional)
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
"""
|
||||
document_name = notification.notification_metadata.get(
|
||||
"document_name", "Document"
|
||||
)
|
||||
|
||||
if error_message:
|
||||
title = f"Failed: {document_name}"
|
||||
message = f"Processing failed: {error_message}"
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {document_name}"
|
||||
message = "Now searchable!"
|
||||
status = "completed"
|
||||
|
||||
metadata_updates = {
|
||||
"processing_stage": "completed" if not error_message else "failed",
|
||||
"error_message": error_message,
|
||||
}
|
||||
|
||||
if document_id is not None:
|
||||
metadata_updates["document_id"] = document_id
|
||||
# Store chunks_count in metadata for debugging, but don't show to user
|
||||
if chunks_count is not None:
|
||||
metadata_updates["chunks_count"] = chunks_count
|
||||
|
||||
return await self.update_notification(
|
||||
session=session,
|
||||
notification=notification,
|
||||
title=title,
|
||||
message=message,
|
||||
status=status,
|
||||
metadata_updates=metadata_updates,
|
||||
)
|
||||
|
||||
|
||||
class MentionNotificationHandler(BaseNotificationHandler):
|
||||
"""Handler for new mention notifications."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("new_mention")
|
||||
|
||||
async def notify_new_mention(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
mentioned_user_id: UUID,
|
||||
mention_id: int,
|
||||
comment_id: int,
|
||||
message_id: int,
|
||||
thread_id: int,
|
||||
thread_title: str,
|
||||
author_id: str,
|
||||
author_name: str,
|
||||
content_preview: str,
|
||||
search_space_id: int,
|
||||
) -> Notification:
|
||||
"""
|
||||
Create notification when a user is @mentioned in a comment.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
mentioned_user_id: User who was mentioned
|
||||
mention_id: ID of the mention record
|
||||
comment_id: ID of the comment containing the mention
|
||||
message_id: ID of the message being commented on
|
||||
thread_id: ID of the chat thread
|
||||
thread_title: Title of the chat thread
|
||||
author_id: ID of the comment author
|
||||
author_name: Display name of the comment author
|
||||
content_preview: First ~100 chars of the comment
|
||||
search_space_id: Search space ID
|
||||
|
||||
Returns:
|
||||
Notification: The created notification
|
||||
"""
|
||||
title = f"{author_name} mentioned you"
|
||||
message = content_preview[:100] + ("..." if len(content_preview) > 100 else "")
|
||||
|
||||
metadata = {
|
||||
"mention_id": mention_id,
|
||||
"comment_id": comment_id,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"thread_title": thread_title,
|
||||
"author_id": author_id,
|
||||
"author_name": author_name,
|
||||
"content_preview": content_preview[:200],
|
||||
}
|
||||
|
||||
notification = Notification(
|
||||
user_id=mentioned_user_id,
|
||||
search_space_id=search_space_id,
|
||||
type=self.notification_type,
|
||||
title=title,
|
||||
message=message,
|
||||
notification_metadata=metadata,
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(
|
||||
f"Created new_mention notification {notification.id} for user {mentioned_user_id}"
|
||||
)
|
||||
return notification
|
||||
|
||||
|
||||
class NotificationService:
|
||||
"""Service for creating and managing notifications that sync via Electric SQL."""
|
||||
|
||||
# Handler instances
|
||||
connector_indexing = ConnectorIndexingNotificationHandler()
|
||||
document_processing = DocumentProcessingNotificationHandler()
|
||||
mention = MentionNotificationHandler()
|
||||
|
||||
@staticmethod
|
||||
async def create_notification(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
notification_type: str,
|
||||
title: str,
|
||||
message: str,
|
||||
search_space_id: int | None = None,
|
||||
notification_metadata: dict[str, Any] | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Create a notification - Electric SQL will automatically sync it to frontend.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User to notify
|
||||
notification_type: Type of notification (e.g., 'document_processing', 'connector_indexing')
|
||||
title: Notification title
|
||||
message: Notification message
|
||||
search_space_id: Optional search space ID
|
||||
notification_metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
Notification: The created notification
|
||||
"""
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
type=notification_type,
|
||||
title=title,
|
||||
message=message,
|
||||
notification_metadata=notification_metadata or {},
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(f"Created notification {notification.id} for user {user_id}")
|
||||
return notification
|
||||
|
|
@ -445,31 +445,13 @@ async def _index_google_gmail_messages(
|
|||
end_date: str,
|
||||
):
|
||||
"""Index Google Gmail messages with new session."""
|
||||
from datetime import datetime
|
||||
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_google_gmail_indexing,
|
||||
)
|
||||
|
||||
# Parse dates to calculate days_back
|
||||
max_messages = 100
|
||||
days_back = 30 # Default
|
||||
|
||||
if start_date:
|
||||
try:
|
||||
# Parse start_date (format: YYYY-MM-DD)
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
# Calculate days back from now
|
||||
days_back = (datetime.now() - start_dt).days
|
||||
# Ensure at least 1 day
|
||||
days_back = max(1, days_back)
|
||||
except ValueError:
|
||||
# If parsing fails, use default
|
||||
days_back = 30
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_google_gmail_indexing(
|
||||
session, connector_id, search_space_id, user_id, max_messages, days_back
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -479,7 +461,7 @@ def index_google_drive_files_task(
|
|||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict, # Dictionary with 'folders' and 'files' lists
|
||||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||
):
|
||||
"""Celery task to index Google Drive folders and files."""
|
||||
import asyncio
|
||||
|
|
@ -504,7 +486,7 @@ async def _index_google_drive_files(
|
|||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict, # Dictionary with 'folders' and 'files' lists
|
||||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||
):
|
||||
"""Index Google Drive folders and files with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
"""Celery tasks for document processing."""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.document_processors import (
|
||||
add_extension_received_document,
|
||||
|
|
@ -84,6 +86,22 @@ async def _process_extension_document(
|
|||
async with get_celery_session_maker()() as session:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Truncate title for notification display
|
||||
page_title = individual_document.metadata.VisitedWebPageTitle[:50]
|
||||
if len(individual_document.metadata.VisitedWebPageTitle) > 50:
|
||||
page_title += "..."
|
||||
|
||||
# Create notification for document processing
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="EXTENSION",
|
||||
document_name=page_title,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_extension_document",
|
||||
source="document_processor",
|
||||
|
|
@ -97,6 +115,14 @@ async def _process_extension_document(
|
|||
)
|
||||
|
||||
try:
|
||||
# Update notification: parsing stage
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Reading page content",
|
||||
)
|
||||
|
||||
result = await add_extension_received_document(
|
||||
session, individual_document, search_space_id, user_id
|
||||
)
|
||||
|
|
@ -107,12 +133,31 @@ async def _process_extension_document(
|
|||
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash},
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
|
||||
# Update notification for duplicate
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Page already saved (duplicate)",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
|
|
@ -120,6 +165,23 @@ async def _process_extension_document(
|
|||
str(e),
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
|
||||
try:
|
||||
# Refresh notification to ensure it's not stale after any rollback
|
||||
await session.refresh(notification)
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:100],
|
||||
)
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
)
|
||||
|
||||
logger.error(f"Error processing extension document: {e!s}")
|
||||
raise
|
||||
|
||||
|
|
@ -150,6 +212,20 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
async with get_celery_session_maker()() as session:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Extract video title from URL for notification (will be updated later)
|
||||
video_name = url.split("v=")[-1][:11] if "v=" in url else url
|
||||
|
||||
# Create notification for document processing
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="YOUTUBE_VIDEO",
|
||||
document_name=f"YouTube: {video_name}",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_youtube_video",
|
||||
source="document_processor",
|
||||
|
|
@ -158,6 +234,14 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
)
|
||||
|
||||
try:
|
||||
# Update notification: parsing (fetching transcript)
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Fetching video transcript",
|
||||
)
|
||||
|
||||
result = await add_youtube_video_document(
|
||||
session, url, search_space_id, user_id
|
||||
)
|
||||
|
|
@ -172,12 +256,31 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"YouTube video document already exists (duplicate): {url}",
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
|
||||
# Update notification for duplicate
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Video already exists (duplicate)",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
|
|
@ -185,6 +288,23 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
str(e),
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
|
||||
try:
|
||||
# Refresh notification to ensure it's not stale after any rollback
|
||||
await session.refresh(notification)
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:100],
|
||||
)
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
)
|
||||
|
||||
logger.error(f"Error processing YouTube video: {e!s}")
|
||||
raise
|
||||
|
||||
|
|
@ -219,11 +339,31 @@ async def _process_file_upload(
|
|||
file_path: str, filename: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Process file upload with new session."""
|
||||
import os
|
||||
|
||||
from app.tasks.document_processors.file_processors import process_file_in_background
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Get file size for notification metadata
|
||||
try:
|
||||
file_size = os.path.getsize(file_path)
|
||||
except Exception:
|
||||
file_size = None
|
||||
|
||||
# Create notification for document processing
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="FILE",
|
||||
document_name=filename,
|
||||
search_space_id=search_space_id,
|
||||
file_size=file_size,
|
||||
)
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload",
|
||||
source="document_processor",
|
||||
|
|
@ -237,7 +377,7 @@ async def _process_file_upload(
|
|||
)
|
||||
|
||||
try:
|
||||
await process_file_in_background(
|
||||
result = await process_file_in_background(
|
||||
file_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
|
|
@ -245,7 +385,29 @@ async def _process_file_upload(
|
|||
session,
|
||||
task_logger,
|
||||
log_entry,
|
||||
notification=notification,
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
if result:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Duplicate detected
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Document already exists (duplicate)",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Import here to avoid circular dependencies
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -258,7 +420,23 @@ async def _process_file_upload(
|
|||
elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower():
|
||||
error_message = str(e.detail)
|
||||
else:
|
||||
error_message = f"Failed to process file: {filename}"
|
||||
error_message = str(e)[:100]
|
||||
|
||||
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
|
||||
try:
|
||||
# Refresh notification to ensure it's not stale after any rollback
|
||||
await session.refresh(notification)
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=error_message,
|
||||
)
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
)
|
||||
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
|
|
@ -323,6 +501,22 @@ async def _process_circleback_meeting(
|
|||
async with get_celery_session_maker()() as session:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
# Get user_id from metadata if available
|
||||
user_id = metadata.get("user_id")
|
||||
|
||||
# Create notification if user_id is available
|
||||
notification = None
|
||||
if user_id:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="CIRCLEBACK",
|
||||
document_name=f"Meeting: {meeting_name[:40]}",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_circleback_meeting",
|
||||
source="circleback_webhook",
|
||||
|
|
@ -336,6 +530,17 @@ async def _process_circleback_meeting(
|
|||
)
|
||||
|
||||
try:
|
||||
# Update notification: parsing stage
|
||||
if notification:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Reading meeting notes",
|
||||
)
|
||||
)
|
||||
|
||||
result = await add_circleback_meeting_document(
|
||||
session=session,
|
||||
meeting_id=meeting_id,
|
||||
|
|
@ -355,12 +560,29 @@ async def _process_circleback_meeting(
|
|||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Circleback meeting document already exists (duplicate): {meeting_name}",
|
||||
{"duplicate_detected": True, "meeting_id": meeting_id},
|
||||
)
|
||||
|
||||
# Update notification for duplicate
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Meeting already saved (duplicate)",
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
|
|
@ -368,5 +590,21 @@ async def _process_circleback_meeting(
|
|||
str(e),
|
||||
{"error_type": type(e).__name__, "meeting_id": meeting_id},
|
||||
)
|
||||
|
||||
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
|
||||
if notification:
|
||||
try:
|
||||
# Refresh notification to ensure it's not stale after any rollback
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:100],
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
)
|
||||
|
||||
logger.error(f"Error processing Circleback meeting: {e!s}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ async def _check_and_trigger_schedules():
|
|||
index_elasticsearch_documents_task,
|
||||
index_github_repos_task,
|
||||
index_google_calendar_events_task,
|
||||
index_google_drive_files_task,
|
||||
index_google_gmail_messages_task,
|
||||
index_jira_issues_task,
|
||||
index_linear_issues_task,
|
||||
|
|
@ -96,6 +97,7 @@ async def _check_and_trigger_schedules():
|
|||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
}
|
||||
|
||||
# Trigger indexing for each due connector
|
||||
|
|
@ -106,13 +108,57 @@ async def _check_and_trigger_schedules():
|
|||
f"Triggering periodic indexing for connector {connector.id} "
|
||||
f"({connector.connector_type.value})"
|
||||
)
|
||||
task.delay(
|
||||
connector.id,
|
||||
connector.search_space_id,
|
||||
str(connector.user_id),
|
||||
None, # start_date - uses last_indexed_at
|
||||
None, # end_date - uses now
|
||||
)
|
||||
|
||||
# Special handling for Google Drive - uses config for folder/file selection
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
connector_config = connector.config or {}
|
||||
selected_folders = connector_config.get("selected_folders", [])
|
||||
selected_files = connector_config.get("selected_files", [])
|
||||
indexing_options = connector_config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
|
||||
if selected_folders or selected_files:
|
||||
task.delay(
|
||||
connector.id,
|
||||
connector.search_space_id,
|
||||
str(connector.user_id),
|
||||
{
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# No folders/files selected - skip indexing but still update next_scheduled_at
|
||||
# to prevent checking every minute
|
||||
logger.info(
|
||||
f"Google Drive connector {connector.id} has no folders or files selected, "
|
||||
"skipping periodic indexing (will check again at next scheduled time)"
|
||||
)
|
||||
from datetime import timedelta
|
||||
|
||||
connector.next_scheduled_at = now + timedelta(
|
||||
minutes=connector.indexing_frequency_minutes
|
||||
)
|
||||
await session.commit()
|
||||
continue
|
||||
else:
|
||||
task.delay(
|
||||
connector.id,
|
||||
connector.search_space_id,
|
||||
str(connector.user_id),
|
||||
None, # start_date - uses last_indexed_at
|
||||
None, # end_date - uses now
|
||||
)
|
||||
|
||||
# Update next_scheduled_at for next run
|
||||
from datetime import timedelta
|
||||
|
|
|
|||
|
|
@ -423,9 +423,9 @@ async def stream_new_chat(
|
|||
title = title[:27] + "..."
|
||||
doc_names.append(title)
|
||||
if len(doc_names) == 1:
|
||||
processing_parts.append(f"[📖 {doc_names[0]}]")
|
||||
processing_parts.append(f"[{doc_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[📖 {len(doc_names)} docs]")
|
||||
processing_parts.append(f"[{len(doc_names)} docs]")
|
||||
|
||||
last_active_step_items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||
|
||||
|
|
|
|||
|
|
@ -549,7 +549,10 @@ async def index_discord_messages(
|
|||
logger.info(
|
||||
f"Discord indexing completed: {documents_indexed} new messages, {documents_skipped} skipped"
|
||||
)
|
||||
return documents_indexed, result_message
|
||||
return (
|
||||
documents_indexed,
|
||||
None,
|
||||
) # Return None on success (result_message is for logging only)
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ async def index_google_drive_files(
|
|||
use_delta_sync: bool = True,
|
||||
update_last_indexed: bool = True,
|
||||
max_files: int = 500,
|
||||
include_subfolders: bool = False,
|
||||
) -> tuple[int, str | None]:
|
||||
"""
|
||||
Index Google Drive files for a specific connector.
|
||||
|
|
@ -51,6 +52,7 @@ async def index_google_drive_files(
|
|||
use_delta_sync: Whether to use change tracking for incremental sync
|
||||
update_last_indexed: Whether to update last_indexed_at timestamp
|
||||
max_files: Maximum number of files to index
|
||||
include_subfolders: Whether to recursively index files in subfolders
|
||||
|
||||
Returns:
|
||||
Tuple of (number_of_indexed_files, error_message)
|
||||
|
|
@ -144,6 +146,7 @@ async def index_google_drive_files(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
max_files=max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Using full scan for connector {connector_id}")
|
||||
|
|
@ -159,6 +162,7 @@ async def index_google_drive_files(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
max_files=max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
)
|
||||
|
||||
documents_indexed, documents_skipped = result
|
||||
|
|
@ -168,6 +172,9 @@ async def index_google_drive_files(
|
|||
if new_token and not token_error:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
# Refresh connector to reload attributes that may have been expired by earlier commits
|
||||
await session.refresh(connector)
|
||||
|
||||
if "folder_tokens" not in connector.config:
|
||||
connector.config["folder_tokens"] = {}
|
||||
connector.config["folder_tokens"][target_folder_id] = new_token
|
||||
|
|
@ -375,60 +382,89 @@ async def _index_full_scan(
|
|||
task_logger: TaskLoggingService,
|
||||
log_entry: any,
|
||||
max_files: int,
|
||||
include_subfolders: bool = False,
|
||||
) -> tuple[int, int]:
|
||||
"""Perform full scan indexing of a folder."""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
{"stage": "full_scan", "folder_id": folder_id},
|
||||
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
|
||||
{
|
||||
"stage": "full_scan",
|
||||
"folder_id": folder_id,
|
||||
"include_subfolders": include_subfolders,
|
||||
},
|
||||
)
|
||||
|
||||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
page_token = None
|
||||
files_processed = 0
|
||||
|
||||
while files_processed < max_files:
|
||||
files, next_token, error = await get_files_in_folder(
|
||||
drive_client, folder_id, include_subfolders=False, page_token=page_token
|
||||
)
|
||||
# Queue of folders to process: (folder_id, folder_name)
|
||||
folders_to_process = [(folder_id, folder_name)]
|
||||
|
||||
if error:
|
||||
logger.error(f"Error listing files: {error}")
|
||||
break
|
||||
while folders_to_process and files_processed < max_files:
|
||||
current_folder_id, current_folder_name = folders_to_process.pop(0)
|
||||
logger.info(f"Processing folder: {current_folder_name} ({current_folder_id})")
|
||||
page_token = None
|
||||
|
||||
if not files:
|
||||
break
|
||||
|
||||
for file in files:
|
||||
if files_processed >= max_files:
|
||||
break
|
||||
|
||||
files_processed += 1
|
||||
|
||||
indexed, skipped = await _process_single_file(
|
||||
drive_client=drive_client,
|
||||
session=session,
|
||||
file=file,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
while files_processed < max_files:
|
||||
# Get files and folders in current folder
|
||||
# include_subfolders=True here so we get folder items to queue them
|
||||
files, next_token, error = await get_files_in_folder(
|
||||
drive_client,
|
||||
current_folder_id,
|
||||
include_subfolders=True,
|
||||
page_token=page_token,
|
||||
)
|
||||
|
||||
documents_indexed += indexed
|
||||
documents_skipped += skipped
|
||||
if error:
|
||||
logger.error(f"Error listing files in {current_folder_name}: {error}")
|
||||
break
|
||||
|
||||
if documents_indexed % 10 == 0 and documents_indexed > 0:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Committed batch: {documents_indexed} files indexed so far"
|
||||
if not files:
|
||||
break
|
||||
|
||||
for file in files:
|
||||
if files_processed >= max_files:
|
||||
break
|
||||
|
||||
mime_type = file.get("mimeType", "")
|
||||
|
||||
# If this is a folder and include_subfolders is enabled, queue it for processing
|
||||
if mime_type == "application/vnd.google-apps.folder":
|
||||
if include_subfolders:
|
||||
folders_to_process.append(
|
||||
(file["id"], file.get("name", "Unknown"))
|
||||
)
|
||||
logger.debug(f"Queued subfolder: {file.get('name', 'Unknown')}")
|
||||
continue
|
||||
|
||||
# Process the file
|
||||
files_processed += 1
|
||||
|
||||
indexed, skipped = await _process_single_file(
|
||||
drive_client=drive_client,
|
||||
session=session,
|
||||
file=file,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
)
|
||||
|
||||
page_token = next_token
|
||||
if not page_token:
|
||||
break
|
||||
documents_indexed += indexed
|
||||
documents_skipped += skipped
|
||||
|
||||
if documents_indexed % 10 == 0 and documents_indexed > 0:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Committed batch: {documents_indexed} files indexed so far"
|
||||
)
|
||||
|
||||
page_token = next_token
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Full scan complete: {documents_indexed} indexed, {documents_skipped} skipped"
|
||||
|
|
@ -448,8 +484,13 @@ async def _index_with_delta_sync(
|
|||
task_logger: TaskLoggingService,
|
||||
log_entry: any,
|
||||
max_files: int,
|
||||
include_subfolders: bool = False,
|
||||
) -> tuple[int, int]:
|
||||
"""Perform delta sync indexing using change tracking."""
|
||||
"""Perform delta sync indexing using change tracking.
|
||||
|
||||
Note: include_subfolders is accepted for API consistency but delta sync
|
||||
automatically tracks changes across all folders including subfolders.
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting delta sync from token: {start_page_token[:20]}...",
|
||||
|
|
@ -515,6 +556,131 @@ async def _index_with_delta_sync(
|
|||
return documents_indexed, documents_skipped
|
||||
|
||||
|
||||
async def _check_rename_only_update(
|
||||
session: AsyncSession,
|
||||
file: dict,
|
||||
search_space_id: int,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if a file only needs a rename update (no content change).
|
||||
|
||||
Uses md5Checksum comparison (preferred) or modifiedTime (fallback for Google Workspace files)
|
||||
to detect if content has changed. This optimization prevents unnecessary ETL API calls
|
||||
(Docling/LlamaCloud) for rename-only operations.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
file: File metadata from Google Drive API
|
||||
search_space_id: ID of the search space
|
||||
|
||||
Returns:
|
||||
Tuple of (is_rename_only, message)
|
||||
- (True, message): Only filename changed, document was updated
|
||||
- (False, None): Content changed or new file, needs full processing
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import Document
|
||||
|
||||
file_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
incoming_md5 = file.get("md5Checksum") # None for Google Workspace files
|
||||
incoming_modified_time = file.get("modifiedTime")
|
||||
|
||||
if not file_id:
|
||||
return False, None
|
||||
|
||||
# Try to find existing document by file_id-based hash (primary method)
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
# If not found by primary hash, try searching by metadata (for legacy documents)
|
||||
if not existing_document:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
Document.document_metadata["google_drive_file_id"].astext == file_id,
|
||||
)
|
||||
)
|
||||
existing_document = result.scalar_one_or_none()
|
||||
if existing_document:
|
||||
logger.debug(f"Found legacy document by metadata for file_id: {file_id}")
|
||||
|
||||
if not existing_document:
|
||||
# New file, needs full processing
|
||||
return False, None
|
||||
|
||||
# Get stored checksums/timestamps from document metadata
|
||||
doc_metadata = existing_document.document_metadata or {}
|
||||
stored_md5 = doc_metadata.get("md5_checksum")
|
||||
stored_modified_time = doc_metadata.get("modified_time")
|
||||
|
||||
# Determine if content changed using md5Checksum (preferred) or modifiedTime (fallback)
|
||||
content_unchanged = False
|
||||
|
||||
if incoming_md5 and stored_md5:
|
||||
# Best case: Compare md5 checksums (only changes when content changes, not on rename)
|
||||
content_unchanged = incoming_md5 == stored_md5
|
||||
logger.debug(f"MD5 comparison for {file_name}: unchanged={content_unchanged}")
|
||||
elif incoming_md5 and not stored_md5:
|
||||
# Have incoming md5 but no stored md5 (legacy doc) - need to reprocess to store it
|
||||
logger.debug(
|
||||
f"No stored md5 for {file_name}, will reprocess to store md5_checksum"
|
||||
)
|
||||
return False, None
|
||||
elif not incoming_md5:
|
||||
# Google Workspace file (no md5Checksum available) - fall back to modifiedTime
|
||||
# Note: modifiedTime is less reliable as it changes on rename too, but it's the best we have
|
||||
if incoming_modified_time and stored_modified_time:
|
||||
content_unchanged = incoming_modified_time == stored_modified_time
|
||||
logger.debug(
|
||||
f"ModifiedTime fallback for Google Workspace file {file_name}: unchanged={content_unchanged}"
|
||||
)
|
||||
else:
|
||||
# No stored modifiedTime (legacy) - reprocess to store it
|
||||
return False, None
|
||||
|
||||
if content_unchanged:
|
||||
# Content hasn't changed - check if filename changed
|
||||
old_name = doc_metadata.get("FILE_NAME") or doc_metadata.get(
|
||||
"google_drive_file_name"
|
||||
)
|
||||
|
||||
if old_name and old_name != file_name:
|
||||
# Rename-only update - update the document without re-processing
|
||||
existing_document.title = file_name
|
||||
if not existing_document.document_metadata:
|
||||
existing_document.document_metadata = {}
|
||||
existing_document.document_metadata["FILE_NAME"] = file_name
|
||||
existing_document.document_metadata["google_drive_file_name"] = file_name
|
||||
# Also update modified_time for Google Workspace files (since it changed on rename)
|
||||
if incoming_modified_time:
|
||||
existing_document.document_metadata["modified_time"] = (
|
||||
incoming_modified_time
|
||||
)
|
||||
flag_modified(existing_document, "document_metadata")
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Rename-only update: '{old_name}' → '{file_name}' (skipped ETL)"
|
||||
)
|
||||
return (
|
||||
True,
|
||||
f"File renamed: '{old_name}' → '{file_name}' (no content change)",
|
||||
)
|
||||
else:
|
||||
# Neither content nor name changed
|
||||
logger.debug(f"File unchanged: {file_name}")
|
||||
return True, "File unchanged (same content and name)"
|
||||
|
||||
# Content changed - needs full processing
|
||||
return False, None
|
||||
|
||||
|
||||
async def _process_single_file(
|
||||
drive_client: GoogleDriveClient,
|
||||
session: AsyncSession,
|
||||
|
|
@ -537,6 +703,27 @@ async def _process_single_file(
|
|||
try:
|
||||
logger.info(f"Processing file: {file_name} ({mime_type})")
|
||||
|
||||
# Early check: Is this a rename-only update?
|
||||
# This optimization prevents downloading and ETL processing for files
|
||||
# where only the name changed but content is the same.
|
||||
is_rename_only, rename_message = await _check_rename_only_update(
|
||||
session=session,
|
||||
file=file,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
if is_rename_only:
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Skipped ETL for {file_name}: {rename_message}",
|
||||
{"status": "rename_only", "reason": rename_message},
|
||||
)
|
||||
# Return 1 for renamed files (they are "indexed" in the sense that they're updated)
|
||||
# Return 0 for unchanged files
|
||||
if "renamed" in (rename_message or "").lower():
|
||||
return 1, 0
|
||||
return 0, 1
|
||||
|
||||
_, error, _ = await download_and_process_file(
|
||||
client=drive_client,
|
||||
file=file,
|
||||
|
|
@ -564,7 +751,15 @@ async def _process_single_file(
|
|||
|
||||
|
||||
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
|
||||
"""Remove a document that was deleted in Drive."""
|
||||
"""Remove a document that was deleted in Drive.
|
||||
|
||||
Handles both new (file_id-based) and legacy (filename-based) hash schemes.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document
|
||||
|
||||
# First try with file_id-based hash (new method)
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
|
|
@ -573,6 +768,19 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
|
|||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# If not found, search by metadata (for legacy documents with filename-based hash)
|
||||
if not existing_document:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
Document.document_metadata["google_drive_file_id"].astext == file_id,
|
||||
)
|
||||
)
|
||||
existing_document = result.scalar_one_or_none()
|
||||
if existing_document:
|
||||
logger.info(f"Found legacy document by metadata for file_id: {file_id}")
|
||||
|
||||
if existing_document:
|
||||
await session.delete(existing_document)
|
||||
logger.info(f"Removed deleted file document: {file_id}")
|
||||
|
|
|
|||
|
|
@ -464,7 +464,10 @@ async def index_notion_pages(
|
|||
# Clean up the async client
|
||||
await notion_client.close()
|
||||
|
||||
return total_processed, result_message
|
||||
return (
|
||||
total_processed,
|
||||
None,
|
||||
) # Return None on success (result_message is for logging only)
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -413,7 +413,10 @@ async def index_slack_messages(
|
|||
logger.info(
|
||||
f"Slack indexing completed: {documents_indexed} new channels, {documents_skipped} skipped"
|
||||
)
|
||||
return total_processed, result_message
|
||||
return (
|
||||
total_processed,
|
||||
None,
|
||||
) # Return None on success (result_message is for logging only)
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -460,7 +460,10 @@ async def index_teams_messages(
|
|||
documents_indexed,
|
||||
documents_skipped,
|
||||
)
|
||||
return total_processed, result_message
|
||||
return (
|
||||
total_processed,
|
||||
None,
|
||||
) # Return None on success (result_message is for logging only)
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -371,17 +371,14 @@ async def index_crawled_urls(
|
|||
)
|
||||
await session.commit()
|
||||
|
||||
# Build result message
|
||||
result_message = None
|
||||
# Log failed URLs if any (for debugging purposes)
|
||||
if failed_urls:
|
||||
failed_summary = "; ".join(
|
||||
[f"{url}: {error}" for url, error in failed_urls[:5]]
|
||||
)
|
||||
if len(failed_urls) > 5:
|
||||
failed_summary += f" (and {len(failed_urls) - 5} more)"
|
||||
result_message = (
|
||||
f"Completed with {len(failed_urls)} failures: {failed_summary}"
|
||||
)
|
||||
logger.warning(f"Some URLs failed to index: {failed_summary}")
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
|
|
@ -400,7 +397,10 @@ async def index_crawled_urls(
|
|||
f"{documents_updated} updated, {documents_skipped} skipped, "
|
||||
f"{len(failed_urls)} failed"
|
||||
)
|
||||
return total_processed, result_message
|
||||
return (
|
||||
total_processed,
|
||||
None,
|
||||
) # Return None on success (result_message is for logging only)
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentType, Log
|
||||
from app.db import Document, DocumentType, Log, Notification
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
convert_document_to_markdown,
|
||||
|
|
@ -30,6 +31,7 @@ from app.utils.document_converters import (
|
|||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
get_current_timestamp,
|
||||
)
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
|
@ -48,6 +50,160 @@ LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
|
|||
)
|
||||
|
||||
|
||||
def get_google_drive_unique_identifier(
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
) -> tuple[str, str | None]:
|
||||
"""
|
||||
Get unique identifier hash for a file, with special handling for Google Drive.
|
||||
|
||||
For Google Drive files, uses file_id as the unique identifier (doesn't change on rename).
|
||||
For other files, uses filename.
|
||||
|
||||
Args:
|
||||
connector: Optional connector info dict with type and metadata
|
||||
filename: The filename (used for non-Google Drive files or as fallback)
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
Tuple of (primary_hash, legacy_hash or None)
|
||||
- For Google Drive: (file_id_based_hash, filename_based_hash for migration)
|
||||
- For other sources: (filename_based_hash, None)
|
||||
"""
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
metadata = connector.get("metadata", {})
|
||||
file_id = metadata.get("google_drive_file_id")
|
||||
|
||||
if file_id:
|
||||
# New method: use file_id as unique identifier (doesn't change on rename)
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
# Legacy method: for backward compatibility with existing documents
|
||||
# that were indexed with filename-based hash
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, legacy_hash
|
||||
|
||||
# For non-Google Drive files, use filename as before
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, None
|
||||
|
||||
|
||||
async def handle_existing_document_update(
|
||||
session: AsyncSession,
|
||||
existing_document: Document,
|
||||
content_hash: str,
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
primary_hash: str,
|
||||
) -> tuple[bool, Document | None]:
|
||||
"""
|
||||
Handle update logic for an existing document.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
existing_document: The existing document found in database
|
||||
content_hash: Hash of the new content
|
||||
connector: Optional connector info
|
||||
filename: Current filename
|
||||
primary_hash: The primary hash (file_id based for Google Drive)
|
||||
|
||||
Returns:
|
||||
Tuple of (should_skip_processing, document_to_return)
|
||||
- (True, document): Content unchanged, just return existing document
|
||||
- (False, None): Content changed, need to re-process
|
||||
"""
|
||||
# Check if this document needs hash migration (found via legacy hash)
|
||||
if existing_document.unique_identifier_hash != primary_hash:
|
||||
existing_document.unique_identifier_hash = primary_hash
|
||||
logging.info(f"Migrated document to file_id-based identifier: {filename}")
|
||||
|
||||
# Check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Content unchanged - check if we need to update metadata (e.g., filename changed)
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
connector_metadata = connector.get("metadata", {})
|
||||
new_name = connector_metadata.get("google_drive_file_name")
|
||||
# Check both possible keys for old name (FILE_NAME is used in stored documents)
|
||||
doc_metadata = existing_document.document_metadata or {}
|
||||
old_name = doc_metadata.get("FILE_NAME") or doc_metadata.get(
|
||||
"google_drive_file_name"
|
||||
)
|
||||
|
||||
if new_name and old_name and old_name != new_name:
|
||||
# File was renamed - update title and metadata, skip expensive processing
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
existing_document.title = new_name
|
||||
if not existing_document.document_metadata:
|
||||
existing_document.document_metadata = {}
|
||||
existing_document.document_metadata["FILE_NAME"] = new_name
|
||||
existing_document.document_metadata["google_drive_file_name"] = new_name
|
||||
flag_modified(existing_document, "document_metadata")
|
||||
await session.commit()
|
||||
logging.info(
|
||||
f"File renamed in Google Drive: '{old_name}' → '{new_name}' (no re-processing needed)"
|
||||
)
|
||||
|
||||
logging.info(f"Document for file {filename} unchanged. Skipping.")
|
||||
return True, existing_document
|
||||
else:
|
||||
# Content has changed - need to re-process
|
||||
logging.info(f"Content changed for file {filename}. Updating document.")
|
||||
return False, None
|
||||
|
||||
|
||||
async def find_existing_document_with_migration(
|
||||
session: AsyncSession,
|
||||
primary_hash: str,
|
||||
legacy_hash: str | None,
|
||||
content_hash: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Find existing document, checking both new hash and legacy hash for migration,
|
||||
with fallback to content_hash for cross-source deduplication.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
primary_hash: The primary hash (file_id based for Google Drive)
|
||||
legacy_hash: The legacy hash (filename based) for migration, or None
|
||||
content_hash: The content hash for fallback deduplication, or None
|
||||
|
||||
Returns:
|
||||
Existing document if found, None otherwise
|
||||
"""
|
||||
# First check with primary hash (new method)
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
# If not found and we have a legacy hash, check with that (migration path)
|
||||
if not existing_document and legacy_hash:
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
"Found legacy document (filename-based hash), will migrate to file_id-based hash"
|
||||
)
|
||||
|
||||
# Fallback: check by content_hash to catch duplicates from different sources
|
||||
# This prevents unique constraint violations when the same content exists
|
||||
# under a different unique_identifier (e.g., manual upload vs Google Drive)
|
||||
if not existing_document and content_hash:
|
||||
existing_document = await check_duplicate_document(session, content_hash)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
f"Found duplicate content from different source (content_hash match). "
|
||||
f"Original document ID: {existing_document.id}, type: {existing_document.document_type}"
|
||||
)
|
||||
|
||||
return existing_document
|
||||
|
||||
|
||||
async def parse_with_llamacloud_retry(
|
||||
file_path: str,
|
||||
estimated_pages: int,
|
||||
|
|
@ -157,6 +313,7 @@ async def add_received_file_document_using_unstructured(
|
|||
unstructured_processed_elements: list[LangChainDocument],
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store a file document using Unstructured service.
|
||||
|
|
@ -167,6 +324,7 @@ async def add_received_file_document_using_unstructured(
|
|||
unstructured_processed_elements: Processed elements from Unstructured
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
connector: Optional connector info for Google Drive files
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
|
|
@ -176,29 +334,32 @@ async def add_received_file_document_using_unstructured(
|
|||
unstructured_processed_elements
|
||||
)
|
||||
|
||||
# Generate unique identifier hash for this file
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file_name, search_space_id
|
||||
# Generate unique identifier hash (uses file_id for Google Drive, filename for others)
|
||||
primary_hash, legacy_hash = get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
# Check if document exists (with migration support for Google Drive and content_hash fallback)
|
||||
existing_document = await find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
logging.info(f"Document for file {file_name} unchanged. Skipping.")
|
||||
return existing_document
|
||||
else:
|
||||
# Content has changed - update the existing document
|
||||
logging.info(
|
||||
f"Content changed for file {file_name}. Updating document."
|
||||
)
|
||||
# Handle existing document (rename detection, content change check)
|
||||
should_skip, doc = await handle_existing_document_update(
|
||||
session,
|
||||
existing_document,
|
||||
content_hash,
|
||||
connector,
|
||||
file_name,
|
||||
primary_hash,
|
||||
)
|
||||
if should_skip:
|
||||
return doc
|
||||
# Content changed - continue to update
|
||||
|
||||
# Get user's long context LLM (needed for both create and update)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
|
|
@ -250,10 +411,15 @@ async def add_received_file_document_using_unstructured(
|
|||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_type=doc_type,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
"ETL_SERVICE": "UNSTRUCTURED",
|
||||
|
|
@ -262,7 +428,7 @@ async def add_received_file_document_using_unstructured(
|
|||
embedding=summary_embedding,
|
||||
chunks=chunks,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=primary_hash,
|
||||
blocknote_document=blocknote_json,
|
||||
content_needs_reindexing=False,
|
||||
updated_at=get_current_timestamp(),
|
||||
|
|
@ -287,6 +453,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
llamacloud_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by LlamaCloud.
|
||||
|
|
@ -297,6 +464,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
llamacloud_markdown_document: Markdown content from LlamaCloud parsing
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
connector: Optional connector info for Google Drive files
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
|
|
@ -305,29 +473,32 @@ async def add_received_file_document_using_llamacloud(
|
|||
# Combine all markdown documents into one
|
||||
file_in_markdown = llamacloud_markdown_document
|
||||
|
||||
# Generate unique identifier hash for this file
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file_name, search_space_id
|
||||
# Generate unique identifier hash (uses file_id for Google Drive, filename for others)
|
||||
primary_hash, legacy_hash = get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
# Check if document exists (with migration support for Google Drive and content_hash fallback)
|
||||
existing_document = await find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
logging.info(f"Document for file {file_name} unchanged. Skipping.")
|
||||
return existing_document
|
||||
else:
|
||||
# Content has changed - update the existing document
|
||||
logging.info(
|
||||
f"Content changed for file {file_name}. Updating document."
|
||||
)
|
||||
# Handle existing document (rename detection, content change check)
|
||||
should_skip, doc = await handle_existing_document_update(
|
||||
session,
|
||||
existing_document,
|
||||
content_hash,
|
||||
connector,
|
||||
file_name,
|
||||
primary_hash,
|
||||
)
|
||||
if should_skip:
|
||||
return doc
|
||||
# Content changed - continue to update
|
||||
|
||||
# Get user's long context LLM (needed for both create and update)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
|
|
@ -379,10 +550,15 @@ async def add_received_file_document_using_llamacloud(
|
|||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_type=doc_type,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
"ETL_SERVICE": "LLAMACLOUD",
|
||||
|
|
@ -391,7 +567,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
embedding=summary_embedding,
|
||||
chunks=chunks,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=primary_hash,
|
||||
blocknote_document=blocknote_json,
|
||||
content_needs_reindexing=False,
|
||||
updated_at=get_current_timestamp(),
|
||||
|
|
@ -418,6 +594,7 @@ async def add_received_file_document_using_docling(
|
|||
docling_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by Docling.
|
||||
|
|
@ -428,6 +605,7 @@ async def add_received_file_document_using_docling(
|
|||
docling_markdown_document: Markdown content from Docling parsing
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
connector: Optional connector info for Google Drive files
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
|
|
@ -435,35 +613,38 @@ async def add_received_file_document_using_docling(
|
|||
try:
|
||||
file_in_markdown = docling_markdown_document
|
||||
|
||||
# Generate unique identifier hash for this file
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file_name, search_space_id
|
||||
# Generate unique identifier hash (uses file_id for Google Drive, filename for others)
|
||||
primary_hash, legacy_hash = get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
# Check if document exists (with migration support for Google Drive and content_hash fallback)
|
||||
existing_document = await find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
logging.info(f"Document for file {file_name} unchanged. Skipping.")
|
||||
return existing_document
|
||||
else:
|
||||
# Content has changed - update the existing document
|
||||
logging.info(
|
||||
f"Content changed for file {file_name}. Updating document."
|
||||
)
|
||||
# Handle existing document (rename detection, content change check)
|
||||
should_skip, doc = await handle_existing_document_update(
|
||||
session,
|
||||
existing_document,
|
||||
content_hash,
|
||||
connector,
|
||||
file_name,
|
||||
primary_hash,
|
||||
)
|
||||
if should_skip:
|
||||
return doc
|
||||
# Content changed - continue to update
|
||||
|
||||
# Get user's long context LLM (needed for both create and update)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
f"No long context LLM configured for user {user_id} in search_space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary using chunked processing for large documents
|
||||
|
|
@ -533,10 +714,15 @@ async def add_received_file_document_using_docling(
|
|||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_type=doc_type,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
"ETL_SERVICE": "DOCLING",
|
||||
|
|
@ -545,15 +731,15 @@ async def add_received_file_document_using_docling(
|
|||
embedding=summary_embedding,
|
||||
chunks=chunks,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=primary_hash,
|
||||
blocknote_document=blocknote_json,
|
||||
content_needs_reindexing=False,
|
||||
updated_at=get_current_timestamp(),
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
|
|
@ -594,10 +780,23 @@ async def process_file_in_background(
|
|||
log_entry: Log,
|
||||
connector: dict
|
||||
| None = None, # Optional: {"type": "GOOGLE_DRIVE_FILE", "metadata": {...}}
|
||||
):
|
||||
notification: Notification
|
||||
| None = None, # Optional notification for progress updates
|
||||
) -> Document | None:
|
||||
try:
|
||||
# Check if the file is a markdown or text file
|
||||
if filename.lower().endswith((".md", ".markdown", ".txt")):
|
||||
# Update notification: parsing stage
|
||||
if notification:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Reading file",
|
||||
)
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing markdown/text file: {filename}",
|
||||
|
|
@ -617,6 +816,14 @@ async def process_file_in_background(
|
|||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Update notification: chunking stage
|
||||
if notification:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking"
|
||||
)
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating document from markdown content: {filename}",
|
||||
|
|
@ -628,7 +835,7 @@ async def process_file_in_background(
|
|||
|
||||
# Process markdown directly through specialized function
|
||||
result = await add_received_markdown_file_document(
|
||||
session, filename, markdown_content, search_space_id, user_id
|
||||
session, filename, markdown_content, search_space_id, user_id, connector
|
||||
)
|
||||
|
||||
if connector:
|
||||
|
|
@ -644,17 +851,30 @@ async def process_file_in_background(
|
|||
"file_type": "markdown",
|
||||
},
|
||||
)
|
||||
return result
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "markdown"},
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if the file is an audio file
|
||||
elif filename.lower().endswith(
|
||||
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
|
||||
):
|
||||
# Update notification: parsing stage (transcription)
|
||||
if notification:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Transcribing audio",
|
||||
)
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing audio file for transcription: {filename}",
|
||||
|
|
@ -738,6 +958,14 @@ async def process_file_in_background(
|
|||
},
|
||||
)
|
||||
|
||||
# Update notification: chunking stage
|
||||
if notification:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking"
|
||||
)
|
||||
)
|
||||
|
||||
# Clean up the temp file
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
|
|
@ -747,7 +975,7 @@ async def process_file_in_background(
|
|||
|
||||
# Process transcription as markdown document
|
||||
result = await add_received_markdown_file_document(
|
||||
session, filename, transcribed_text, search_space_id, user_id
|
||||
session, filename, transcribed_text, search_space_id, user_id, connector
|
||||
)
|
||||
|
||||
if connector:
|
||||
|
|
@ -765,12 +993,14 @@ async def process_file_in_background(
|
|||
"stt_service": stt_service_type,
|
||||
},
|
||||
)
|
||||
return result
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Audio file transcript already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "audio"},
|
||||
)
|
||||
return None
|
||||
|
||||
else:
|
||||
# Import page limit service
|
||||
|
|
@ -835,6 +1065,15 @@ async def process_file_in_background(
|
|||
) from e
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
# Update notification: parsing stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Extracting content",
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Unstructured ETL: {filename}",
|
||||
|
|
@ -860,6 +1099,12 @@ async def process_file_in_background(
|
|||
|
||||
docs = await loader.aload()
|
||||
|
||||
# Update notification: chunking stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking", chunks_count=len(docs)
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Unstructured ETL completed, creating document: {filename}",
|
||||
|
|
@ -895,7 +1140,7 @@ async def process_file_in_background(
|
|||
|
||||
# Pass the documents to the existing background task
|
||||
result = await add_received_file_document_using_unstructured(
|
||||
session, filename, docs, search_space_id, user_id
|
||||
session, filename, docs, search_space_id, user_id, connector
|
||||
)
|
||||
|
||||
if connector:
|
||||
|
|
@ -919,6 +1164,7 @@ async def process_file_in_background(
|
|||
"pages_processed": final_page_count,
|
||||
},
|
||||
)
|
||||
return result
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
|
|
@ -929,8 +1175,18 @@ async def process_file_in_background(
|
|||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
# Update notification: parsing stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Extracting content",
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with LlamaCloud ETL: {filename}",
|
||||
|
|
@ -964,6 +1220,15 @@ async def process_file_in_background(
|
|||
split_by_page=False
|
||||
)
|
||||
|
||||
# Update notification: chunking stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="chunking",
|
||||
chunks_count=len(markdown_documents),
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"LlamaCloud parsing completed, creating documents: {filename}",
|
||||
|
|
@ -1023,6 +1288,7 @@ async def process_file_in_background(
|
|||
llamacloud_markdown_document=markdown_content,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
# Track if this document was successfully created
|
||||
|
|
@ -1056,6 +1322,7 @@ async def process_file_in_background(
|
|||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
return last_created_doc
|
||||
else:
|
||||
# All documents were duplicates (markdown_documents was not empty, but all returned None)
|
||||
await task_logger.log_task_success(
|
||||
|
|
@ -1068,8 +1335,18 @@ async def process_file_in_background(
|
|||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
elif app_config.ETL_SERVICE == "DOCLING":
|
||||
# Update notification: parsing stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Extracting content",
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Docling ETL: {filename}",
|
||||
|
|
@ -1152,6 +1429,12 @@ async def process_file_in_background(
|
|||
},
|
||||
)
|
||||
|
||||
# Update notification: chunking stage
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking"
|
||||
)
|
||||
|
||||
# Process the document using our Docling background task
|
||||
doc_result = await add_received_file_document_using_docling(
|
||||
session,
|
||||
|
|
@ -1159,6 +1442,7 @@ async def process_file_in_background(
|
|||
docling_markdown_document=result["content"],
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
if doc_result:
|
||||
|
|
@ -1184,6 +1468,7 @@ async def process_file_in_background(
|
|||
"pages_processed": final_page_count,
|
||||
},
|
||||
)
|
||||
return doc_result
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
|
|
@ -1194,6 +1479,7 @@ async def process_file_in_background(
|
|||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
||||
|
|
|
|||
|
|
@ -19,16 +19,157 @@ from app.utils.document_converters import (
|
|||
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
get_current_timestamp,
|
||||
)
|
||||
|
||||
|
||||
def _get_google_drive_unique_identifier(
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
) -> tuple[str, str | None]:
|
||||
"""
|
||||
Get unique identifier hash for a file, with special handling for Google Drive.
|
||||
|
||||
For Google Drive files, uses file_id as the unique identifier (doesn't change on rename).
|
||||
For other files, uses filename.
|
||||
|
||||
Args:
|
||||
connector: Optional connector info dict with type and metadata
|
||||
filename: The filename (used for non-Google Drive files or as fallback)
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
Tuple of (primary_hash, legacy_hash or None)
|
||||
"""
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
metadata = connector.get("metadata", {})
|
||||
file_id = metadata.get("google_drive_file_id")
|
||||
|
||||
if file_id:
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, legacy_hash
|
||||
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, None
|
||||
|
||||
|
||||
async def _find_existing_document_with_migration(
|
||||
session: AsyncSession,
|
||||
primary_hash: str,
|
||||
legacy_hash: str | None,
|
||||
content_hash: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Find existing document, checking both new hash and legacy hash for migration,
|
||||
with fallback to content_hash for cross-source deduplication.
|
||||
"""
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
if not existing_document and legacy_hash:
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
"Found legacy document (filename-based hash), will migrate to file_id-based hash"
|
||||
)
|
||||
|
||||
# Fallback: check by content_hash to catch duplicates from different sources
|
||||
if not existing_document and content_hash:
|
||||
existing_document = await check_duplicate_document(session, content_hash)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
f"Found duplicate content from different source (content_hash match). "
|
||||
f"Original document ID: {existing_document.id}, type: {existing_document.document_type}"
|
||||
)
|
||||
|
||||
return existing_document
|
||||
|
||||
|
||||
async def _handle_existing_document_update(
|
||||
session: AsyncSession,
|
||||
existing_document: Document,
|
||||
content_hash: str,
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
primary_hash: str,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
) -> tuple[bool, Document | None]:
|
||||
"""
|
||||
Handle update logic for an existing document.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_skip_processing, document_to_return)
|
||||
"""
|
||||
# Check if this document needs hash migration
|
||||
if existing_document.unique_identifier_hash != primary_hash:
|
||||
existing_document.unique_identifier_hash = primary_hash
|
||||
logging.info(f"Migrated document to file_id-based identifier: {filename}")
|
||||
|
||||
# Check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Content unchanged - check if we need to update metadata (e.g., filename changed)
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
connector_metadata = connector.get("metadata", {})
|
||||
new_name = connector_metadata.get("google_drive_file_name")
|
||||
# Check both possible keys for old name (FILE_NAME is used in stored documents)
|
||||
doc_metadata = existing_document.document_metadata or {}
|
||||
old_name = (
|
||||
doc_metadata.get("FILE_NAME")
|
||||
or doc_metadata.get("google_drive_file_name")
|
||||
or doc_metadata.get("file_name")
|
||||
)
|
||||
|
||||
if new_name and old_name and old_name != new_name:
|
||||
# File was renamed - update title and metadata, skip expensive processing
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
existing_document.title = new_name
|
||||
if not existing_document.document_metadata:
|
||||
existing_document.document_metadata = {}
|
||||
existing_document.document_metadata["FILE_NAME"] = new_name
|
||||
existing_document.document_metadata["file_name"] = new_name
|
||||
existing_document.document_metadata["google_drive_file_name"] = new_name
|
||||
flag_modified(existing_document, "document_metadata")
|
||||
await session.commit()
|
||||
logging.info(
|
||||
f"File renamed in Google Drive: '{old_name}' → '{new_name}' (no re-processing needed)"
|
||||
)
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file document unchanged: {filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
},
|
||||
)
|
||||
logging.info(f"Document for markdown file {filename} unchanged. Skipping.")
|
||||
return True, existing_document
|
||||
else:
|
||||
logging.info(
|
||||
f"Content changed for markdown file {filename}. Updating document."
|
||||
)
|
||||
return False, None
|
||||
|
||||
|
||||
async def add_received_markdown_file_document(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
file_in_markdown: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store a markdown file document.
|
||||
|
|
@ -39,6 +180,7 @@ async def add_received_markdown_file_document(
|
|||
file_in_markdown: Content of the markdown file
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
connector: Optional connector info for Google Drive files
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
|
|
@ -58,39 +200,34 @@ async def add_received_markdown_file_document(
|
|||
)
|
||||
|
||||
try:
|
||||
# Generate unique identifier hash for this markdown file
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file_name, search_space_id
|
||||
# Generate unique identifier hash (uses file_id for Google Drive, filename for others)
|
||||
primary_hash, legacy_hash = _get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
# Check if document with this unique identifier already exists
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
# Check if document exists (with migration support for Google Drive and content_hash fallback)
|
||||
existing_document = await _find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file document unchanged: {file_name}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
},
|
||||
)
|
||||
logging.info(
|
||||
f"Document for markdown file {file_name} unchanged. Skipping."
|
||||
)
|
||||
return existing_document
|
||||
else:
|
||||
# Content has changed - update the existing document
|
||||
logging.info(
|
||||
f"Content changed for markdown file {file_name}. Updating document."
|
||||
)
|
||||
# Handle existing document (rename detection, content change check)
|
||||
should_skip, doc = await _handle_existing_document_update(
|
||||
session,
|
||||
existing_document,
|
||||
content_hash,
|
||||
connector,
|
||||
file_name,
|
||||
primary_hash,
|
||||
task_logger,
|
||||
log_entry,
|
||||
)
|
||||
if should_skip:
|
||||
return doc
|
||||
# Content changed - continue to update
|
||||
|
||||
# Get user's long context LLM (needed for both create and update)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
|
|
@ -139,10 +276,15 @@ async def add_received_markdown_file_document(
|
|||
document = existing_document
|
||||
else:
|
||||
# Create new document
|
||||
# Determine document type based on connector
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_type=doc_type,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
},
|
||||
|
|
@ -150,7 +292,7 @@ async def add_received_markdown_file_document(
|
|||
embedding=summary_embedding,
|
||||
chunks=chunks,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=primary_hash,
|
||||
blocknote_document=blocknote_json,
|
||||
updated_at=get_current_timestamp(),
|
||||
)
|
||||
|
|
|
|||
64
surfsense_backend/app/utils/chat_comments.py
Normal file
64
surfsense_backend/app/utils/chat_comments.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""
|
||||
Utility functions for chat comments, including mention parsing.
|
||||
"""
|
||||
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
# Pattern to match @[uuid] mentions in comment content
|
||||
MENTION_PATTERN = re.compile(r"@\[([0-9a-fA-F-]{36})\]")
|
||||
|
||||
|
||||
def parse_mentions(content: str) -> list[UUID]:
|
||||
"""
|
||||
Extract user UUIDs from @[uuid] mentions in content.
|
||||
|
||||
Args:
|
||||
content: Comment text that may contain @[uuid] mentions
|
||||
|
||||
Returns:
|
||||
List of unique user UUIDs found in the content
|
||||
"""
|
||||
matches = MENTION_PATTERN.findall(content)
|
||||
unique_uuids = []
|
||||
seen = set()
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
uuid = UUID(match)
|
||||
if uuid not in seen:
|
||||
seen.add(uuid)
|
||||
unique_uuids.append(uuid)
|
||||
except ValueError:
|
||||
# Invalid UUID format, skip
|
||||
continue
|
||||
|
||||
return unique_uuids
|
||||
|
||||
|
||||
def render_mentions(content: str, user_names: dict[UUID, str]) -> str:
|
||||
"""
|
||||
Replace @[uuid] mentions with @{DisplayName} in content.
|
||||
|
||||
Uses curly braces as delimiters for unambiguous frontend parsing.
|
||||
|
||||
Args:
|
||||
content: Comment text with @[uuid] mentions
|
||||
user_names: Dict mapping user UUIDs to display names
|
||||
|
||||
Returns:
|
||||
Content with mentions rendered as @{DisplayName}
|
||||
"""
|
||||
|
||||
def replace_mention(match: re.Match) -> str:
|
||||
try:
|
||||
uuid = UUID(match.group(1))
|
||||
name = user_names.get(uuid)
|
||||
if name:
|
||||
return f"@{{{name}}}"
|
||||
# Keep original format if user not found
|
||||
return match.group(0)
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
|
||||
return MENTION_PATTERN.sub(replace_mention, content)
|
||||
|
|
@ -11,6 +11,26 @@ cleanup() {
|
|||
|
||||
trap cleanup SIGTERM SIGINT
|
||||
|
||||
# Run database migrations with safeguards
|
||||
echo "Running database migrations..."
|
||||
# Wait for database to be ready (max 30 seconds)
|
||||
for i in {1..30}; do
|
||||
if python -c "from app.db import engine; import asyncio; asyncio.run(engine.dispose())" 2>/dev/null; then
|
||||
echo "Database is ready."
|
||||
break
|
||||
fi
|
||||
echo "Waiting for database... ($i/30)"
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Run migrations with timeout (60 seconds max)
|
||||
if timeout 60 alembic upgrade head 2>&1; then
|
||||
echo "Migrations completed successfully."
|
||||
else
|
||||
echo "WARNING: Migration failed or timed out. Continuing anyway..."
|
||||
echo "You may need to run migrations manually: alembic upgrade head"
|
||||
fi
|
||||
|
||||
echo "Starting FastAPI Backend..."
|
||||
python main.py &
|
||||
backend_pid=$!
|
||||
|
|
|
|||
6035
surfsense_backend/uv.lock
generated
6035
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,5 +1,10 @@
|
|||
NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE
|
||||
NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING
|
||||
|
||||
# Electric SQL
|
||||
NEXT_PUBLIC_ELECTRIC_URL=http://localhost:5133
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
|
||||
# Contact Form Vars - OPTIONAL
|
||||
DATABASE_URL=postgresql://postgres:[YOUR-PASSWORD]@db.sdsf.supabase.co:5432/postgres
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { ChevronDown, ChevronUp, FileX, Plus } from "lucide-react";
|
||||
import { ChevronDown, ChevronUp, FileX, Loader2, Plus } from "lucide-react";
|
||||
import { motion } from "motion/react";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
|
|
@ -114,7 +114,7 @@ export function DocumentsTableShell({
|
|||
{loading ? (
|
||||
<div className="flex h-[400px] w-full items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<div className="h-8 w-8 animate-spin rounded-full border-b-2 border-primary"></div>
|
||||
<Loader2 className="h-8 w-8 animate-spin text-primary" />
|
||||
<p className="text-sm text-muted-foreground">{t("loading")}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,44 +0,0 @@
|
|||
"use client";
|
||||
|
||||
import { Loader2 } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
|
||||
interface ProcessingIndicatorProps {
|
||||
documentProcessorTasksCount: number;
|
||||
}
|
||||
|
||||
export function ProcessingIndicator({ documentProcessorTasksCount }: ProcessingIndicatorProps) {
|
||||
const t = useTranslations("documents");
|
||||
|
||||
// Only show when there are document_processor tasks (uploads), not connector_indexing_task (periodic reindexing)
|
||||
if (documentProcessorTasksCount === 0) return null;
|
||||
|
||||
return (
|
||||
<AnimatePresence>
|
||||
<motion.div
|
||||
initial={{ opacity: 0, height: 0, marginBottom: 0 }}
|
||||
animate={{ opacity: 1, height: "auto", marginBottom: 24 }}
|
||||
exit={{ opacity: 0, height: 0, marginBottom: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<Alert className="border-border bg-primary/5">
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-full bg-primary/10">
|
||||
<Loader2 className="h-5 w-5 animate-spin text-primary" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<AlertTitle className="text-primary font-semibold">
|
||||
{t("processing_documents")}
|
||||
</AlertTitle>
|
||||
<AlertDescription className="text-muted-foreground">
|
||||
{t("active_tasks_count", { count: documentProcessorTasksCount })}
|
||||
</AlertDescription>
|
||||
</div>
|
||||
</div>
|
||||
</Alert>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
);
|
||||
}
|
||||
|
|
@ -6,20 +6,18 @@ import { RefreshCw, SquarePlus, Upload } from "lucide-react";
|
|||
import { motion } from "motion/react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useCallback, useEffect, useId, useMemo, useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useId, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { useLogsSummary } from "@/hooks/use-logs";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { DocumentsFilters } from "./components/DocumentsFilters";
|
||||
import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell";
|
||||
import { PaginationControls } from "./components/PaginationControls";
|
||||
import { ProcessingIndicator } from "./components/ProcessingIndicator";
|
||||
import type { ColumnVisibility } from "./components/types";
|
||||
|
||||
function useDebounced<T>(value: T, delay = 250) {
|
||||
|
|
@ -109,6 +107,52 @@ export default function DocumentsTable() {
|
|||
enabled: !!searchSpaceId && !!debouncedSearch.trim(),
|
||||
});
|
||||
|
||||
// Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected)
|
||||
const showSurfsenseDocs =
|
||||
activeTypes.length === 0 || activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum);
|
||||
|
||||
// Use query for fetching SurfSense docs
|
||||
const {
|
||||
data: surfsenseDocsResponse,
|
||||
isLoading: isSurfsenseDocsLoading,
|
||||
refetch: refetchSurfsenseDocs,
|
||||
} = useQuery({
|
||||
queryKey: ["surfsense-docs", debouncedSearch, pageIndex, pageSize],
|
||||
queryFn: () =>
|
||||
documentsApiService.getSurfsenseDocs({
|
||||
queryParams: {
|
||||
page: pageIndex,
|
||||
page_size: pageSize,
|
||||
title: debouncedSearch.trim() || undefined,
|
||||
},
|
||||
}),
|
||||
staleTime: 3 * 60 * 1000, // 3 minutes
|
||||
enabled: showSurfsenseDocs,
|
||||
});
|
||||
|
||||
// Transform SurfSense docs to match the Document type
|
||||
const surfsenseDocsAsDocuments: Document[] = useMemo(() => {
|
||||
if (!surfsenseDocsResponse?.items) return [];
|
||||
return surfsenseDocsResponse.items.map((doc) => ({
|
||||
id: doc.id,
|
||||
title: doc.title,
|
||||
document_type: "SURFSENSE_DOCS",
|
||||
document_metadata: { source: doc.source },
|
||||
content: doc.content,
|
||||
created_at: new Date().toISOString(),
|
||||
search_space_id: -1, // Special value for global docs
|
||||
}));
|
||||
}, [surfsenseDocsResponse]);
|
||||
|
||||
// Merge type counts with SURFSENSE_DOCS count
|
||||
const typeCounts = useMemo(() => {
|
||||
const counts = { ...(rawTypeCounts || {}) };
|
||||
if (surfsenseDocsResponse?.total) {
|
||||
counts.SURFSENSE_DOCS = surfsenseDocsResponse.total;
|
||||
}
|
||||
return counts;
|
||||
}, [rawTypeCounts, surfsenseDocsResponse?.total]);
|
||||
|
||||
// Extract documents and total based on search state
|
||||
const documents = debouncedSearch.trim()
|
||||
? searchResponse?.items || []
|
||||
|
|
@ -150,30 +194,6 @@ export default function DocumentsTable() {
|
|||
}
|
||||
}, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]);
|
||||
|
||||
// Set up smart polling for active tasks - only polls when tasks are in progress
|
||||
const { summary } = useLogsSummary(searchSpaceId, 24, {
|
||||
enablePolling: true,
|
||||
refetchInterval: 5000, // Poll every 5 seconds when tasks are active
|
||||
});
|
||||
|
||||
// Filter active tasks to only include document_processor tasks (uploads via "add sources")
|
||||
// Exclude connector_indexing_task tasks (periodic reindexing)
|
||||
const documentProcessorTasks =
|
||||
summary?.active_tasks.filter((task) => task.source === "document_processor") || [];
|
||||
const documentProcessorTasksCount = documentProcessorTasks.length;
|
||||
|
||||
const activeTasksCount = summary?.active_tasks.length || 0;
|
||||
const prevActiveTasksCount = useRef(activeTasksCount);
|
||||
|
||||
// Auto-refresh when a task finishes
|
||||
useEffect(() => {
|
||||
if (prevActiveTasksCount.current > activeTasksCount) {
|
||||
// A task has finished!
|
||||
refreshCurrentView();
|
||||
}
|
||||
prevActiveTasksCount.current = activeTasksCount;
|
||||
}, [activeTasksCount, refreshCurrentView]);
|
||||
|
||||
// Create a delete function for single document deletion
|
||||
const deleteDocument = useCallback(
|
||||
async (id: number) => {
|
||||
|
|
@ -262,8 +282,6 @@ export default function DocumentsTable() {
|
|||
</div>
|
||||
</motion.div>
|
||||
|
||||
<ProcessingIndicator documentProcessorTasksCount={documentProcessorTasksCount} />
|
||||
|
||||
<DocumentsFilters
|
||||
typeCounts={rawTypeCounts ?? {}}
|
||||
selectedIds={selectedIds}
|
||||
|
|
|
|||
|
|
@ -438,9 +438,7 @@ export default function EditorPage() {
|
|||
{saving ? (
|
||||
<>
|
||||
<Loader2 className="h-3.5 w-3.5 md:h-4 md:w-4 animate-spin" />
|
||||
<span className="text-xs md:text-sm">
|
||||
{isNewNote ? "Creating..." : "Saving..."}
|
||||
</span>
|
||||
<span className="text-xs md:text-sm">{isNewNote ? "Creating" : "Saving"}</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
|
|
|
|||
|
|
@ -8,10 +8,11 @@ import {
|
|||
} from "@assistant-ui/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue, useSetAtom } from "jotai";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useParams, useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { z } from "zod";
|
||||
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||
import {
|
||||
type MentionedDocumentInfo,
|
||||
mentionedDocumentIdsAtom,
|
||||
|
|
@ -251,6 +252,7 @@ export default function NewChatPage() {
|
|||
const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom);
|
||||
const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
|
||||
const hydratePlanState = useSetAtom(hydratePlanStateAtom);
|
||||
const setCurrentThreadState = useSetAtom(currentThreadAtom);
|
||||
|
||||
// Get current user for author info in shared chats
|
||||
const { data: currentUser } = useAtomValue(currentUserAtom);
|
||||
|
|
@ -365,6 +367,48 @@ export default function NewChatPage() {
|
|||
initializeThread();
|
||||
}, [initializeThread]);
|
||||
|
||||
// Handle scroll to comment from URL query params (e.g., from notification click)
|
||||
const searchParams = useSearchParams();
|
||||
const targetCommentId = searchParams.get("commentId");
|
||||
|
||||
useEffect(() => {
|
||||
if (!targetCommentId || isInitializing || messages.length === 0) return;
|
||||
|
||||
const tryScroll = () => {
|
||||
const el = document.querySelector(`[data-comment-id="${targetCommentId}"]`);
|
||||
if (el) {
|
||||
el.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Try immediately
|
||||
if (tryScroll()) return;
|
||||
|
||||
// Retry every 200ms for up to 10 seconds
|
||||
const intervalId = setInterval(() => {
|
||||
if (tryScroll()) clearInterval(intervalId);
|
||||
}, 200);
|
||||
|
||||
const timeoutId = setTimeout(() => clearInterval(intervalId), 10000);
|
||||
|
||||
return () => {
|
||||
clearInterval(intervalId);
|
||||
clearTimeout(timeoutId);
|
||||
};
|
||||
}, [targetCommentId, isInitializing, messages.length]);
|
||||
|
||||
// Sync current thread state to atom
|
||||
useEffect(() => {
|
||||
setCurrentThreadState({
|
||||
id: currentThread?.id ?? null,
|
||||
visibility: currentThread?.visibility ?? null,
|
||||
hasComments: currentThread?.has_comments ?? false,
|
||||
addingCommentToMessageId: null,
|
||||
});
|
||||
}, [currentThread, setCurrentThreadState]);
|
||||
|
||||
// Cancel ongoing request
|
||||
const cancelRun = useCallback(async () => {
|
||||
if (abortControllerRef.current) {
|
||||
|
|
@ -842,10 +886,32 @@ export default function NewChatPage() {
|
|||
// Persist assistant message (with thinking steps for restoration on refresh)
|
||||
const finalContent = buildContentForPersistence();
|
||||
if (contentParts.length > 0) {
|
||||
appendMessage(currentThreadId, {
|
||||
role: "assistant",
|
||||
content: finalContent,
|
||||
}).catch((err) => console.error("Failed to persist assistant message:", err));
|
||||
try {
|
||||
const savedMessage = await appendMessage(currentThreadId, {
|
||||
role: "assistant",
|
||||
content: finalContent,
|
||||
});
|
||||
|
||||
// Update message ID from temporary to database ID so comments work immediately
|
||||
const newMsgId = `msg-${savedMessage.id}`;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
||||
);
|
||||
|
||||
// Also update thinking steps map with new ID
|
||||
setMessageThinkingSteps((prev) => {
|
||||
const steps = prev.get(assistantMsgId);
|
||||
if (steps) {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(assistantMsgId);
|
||||
newMap.set(newMsgId, steps);
|
||||
return newMap;
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
} catch (err) {
|
||||
console.error("Failed to persist assistant message:", err);
|
||||
}
|
||||
|
||||
// Track successful response
|
||||
trackChatResponseReceived(searchSpaceId, currentThreadId);
|
||||
|
|
@ -860,10 +926,20 @@ export default function NewChatPage() {
|
|||
);
|
||||
if (hasContent && currentThreadId) {
|
||||
const partialContent = buildContentForPersistence();
|
||||
appendMessage(currentThreadId, {
|
||||
role: "assistant",
|
||||
content: partialContent,
|
||||
}).catch((err) => console.error("Failed to persist partial assistant message:", err));
|
||||
try {
|
||||
const savedMessage = await appendMessage(currentThreadId, {
|
||||
role: "assistant",
|
||||
content: partialContent,
|
||||
});
|
||||
|
||||
// Update message ID from temporary to database ID
|
||||
const newMsgId = `msg-${savedMessage.id}`;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("Failed to persist partial assistant message:", err);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1294,7 +1294,7 @@ function CreateInviteDialog({
|
|||
{creating ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
Creating
|
||||
</>
|
||||
) : (
|
||||
"Create Invite"
|
||||
|
|
@ -1471,7 +1471,7 @@ function CreateRoleDialog({
|
|||
{creating ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
Creating
|
||||
</>
|
||||
) : (
|
||||
"Create Role"
|
||||
|
|
|
|||
|
|
@ -157,5 +157,33 @@ button {
|
|||
cursor: pointer;
|
||||
}
|
||||
|
||||
/* Custom scrollbar styles */
|
||||
.scrollbar-thin {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: hsl(var(--muted-foreground) / 0.2) transparent;
|
||||
}
|
||||
|
||||
.scrollbar-thin:hover {
|
||||
scrollbar-color: hsl(var(--muted-foreground) / 0.4) transparent;
|
||||
}
|
||||
|
||||
/* Webkit scrollbar styles */
|
||||
.scrollbar-thin::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-thumb {
|
||||
background-color: hsl(var(--muted-foreground) / 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
|
||||
background-color: hsl(var(--muted-foreground) / 0.4);
|
||||
}
|
||||
|
||||
@source '../node_modules/@llamaindex/chat-ui/**/*.{ts,tsx}';
|
||||
@source '../node_modules/streamdown/dist/*.js';
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import type { Metadata } from "next";
|
|||
import "./globals.css";
|
||||
import { RootProvider } from "fumadocs-ui/provider/next";
|
||||
import { Roboto } from "next/font/google";
|
||||
import { ElectricProvider } from "@/components/providers/ElectricProvider";
|
||||
import { I18nProvider } from "@/components/providers/I18nProvider";
|
||||
import { PostHogProvider } from "@/components/providers/PostHogProvider";
|
||||
import { ThemeProvider } from "@/components/theme/theme-provider";
|
||||
|
|
@ -102,7 +103,9 @@ export default function RootLayout({
|
|||
defaultTheme="light"
|
||||
>
|
||||
<RootProvider>
|
||||
<ReactQueryClientProvider>{children}</ReactQueryClientProvider>
|
||||
<ReactQueryClientProvider>
|
||||
<ElectricProvider>{children}</ElectricProvider>
|
||||
</ReactQueryClientProvider>
|
||||
<Toaster />
|
||||
</RootProvider>
|
||||
</ThemeProvider>
|
||||
|
|
|
|||
72
surfsense_web/atoms/chat-comments/comments-mutation.atoms.ts
Normal file
72
surfsense_web/atoms/chat-comments/comments-mutation.atoms.ts
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import { atomWithMutation } from "jotai-tanstack-query";
|
||||
import { toast } from "sonner";
|
||||
import type {
|
||||
CreateCommentRequest,
|
||||
CreateReplyRequest,
|
||||
DeleteCommentRequest,
|
||||
UpdateCommentRequest,
|
||||
} from "@/contracts/types/chat-comments.types";
|
||||
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { queryClient } from "@/lib/query-client/client";
|
||||
|
||||
export const createCommentMutationAtom = atomWithMutation(() => ({
|
||||
mutationFn: async (request: CreateCommentRequest) => {
|
||||
return chatCommentsApiService.createComment(request);
|
||||
},
|
||||
onSuccess: (_, variables) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.comments.byMessage(variables.message_id),
|
||||
});
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
console.error("Error creating comment:", error);
|
||||
toast.error("Failed to create comment");
|
||||
},
|
||||
}));
|
||||
|
||||
export const createReplyMutationAtom = atomWithMutation(() => ({
|
||||
mutationFn: async (request: CreateReplyRequest & { message_id: number }) => {
|
||||
return chatCommentsApiService.createReply(request);
|
||||
},
|
||||
onSuccess: (_, variables) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.comments.byMessage(variables.message_id),
|
||||
});
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
console.error("Error creating reply:", error);
|
||||
toast.error("Failed to create reply");
|
||||
},
|
||||
}));
|
||||
|
||||
export const updateCommentMutationAtom = atomWithMutation(() => ({
|
||||
mutationFn: async (request: UpdateCommentRequest & { message_id: number }) => {
|
||||
return chatCommentsApiService.updateComment(request);
|
||||
},
|
||||
onSuccess: (_, variables) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.comments.byMessage(variables.message_id),
|
||||
});
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
console.error("Error updating comment:", error);
|
||||
toast.error("Failed to update comment");
|
||||
},
|
||||
}));
|
||||
|
||||
export const deleteCommentMutationAtom = atomWithMutation(() => ({
|
||||
mutationFn: async (request: DeleteCommentRequest & { message_id: number }) => {
|
||||
return chatCommentsApiService.deleteComment(request);
|
||||
},
|
||||
onSuccess: (_, variables) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.comments.byMessage(variables.message_id),
|
||||
});
|
||||
toast.success("Comment deleted");
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
console.error("Error deleting comment:", error);
|
||||
toast.error("Failed to delete comment");
|
||||
},
|
||||
}));
|
||||
52
surfsense_web/atoms/chat/current-thread.atom.ts
Normal file
52
surfsense_web/atoms/chat/current-thread.atom.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import { atom } from "jotai";
|
||||
import type { ChatVisibility } from "@/lib/chat/thread-persistence";
|
||||
|
||||
// TODO: Update `hasComments` to true when the first comment is created on a thread.
|
||||
// Currently it only updates on thread load. The gutter still works because
|
||||
// `addingCommentToMessageId` keeps it open, but the state is technically stale.
|
||||
|
||||
// TODO: Reset `addingCommentToMessageId` to null after a comment is successfully created.
|
||||
// Currently it stays set until navigation or clicking another message's bubble.
|
||||
// Not causing issues since panel visibility is driven by per-message comment count.
|
||||
|
||||
// TODO: Consider calling `resetCurrentThreadAtom` when unmounting the chat page
|
||||
// for explicit cleanup, though React navigation handles this implicitly.
|
||||
|
||||
interface CurrentThreadState {
|
||||
id: number | null;
|
||||
visibility: ChatVisibility | null;
|
||||
hasComments: boolean;
|
||||
addingCommentToMessageId: number | null;
|
||||
}
|
||||
|
||||
const initialState: CurrentThreadState = {
|
||||
id: null,
|
||||
visibility: null,
|
||||
hasComments: false,
|
||||
addingCommentToMessageId: null,
|
||||
};
|
||||
|
||||
export const currentThreadAtom = atom<CurrentThreadState>(initialState);
|
||||
|
||||
export const commentsEnabledAtom = atom(
|
||||
(get) => get(currentThreadAtom).visibility === "SEARCH_SPACE"
|
||||
);
|
||||
|
||||
export const showCommentsGutterAtom = atom((get) => {
|
||||
const thread = get(currentThreadAtom);
|
||||
return (
|
||||
thread.visibility === "SEARCH_SPACE" &&
|
||||
(thread.hasComments || thread.addingCommentToMessageId !== null)
|
||||
);
|
||||
});
|
||||
|
||||
export const addingCommentToMessageIdAtom = atom(
|
||||
(get) => get(currentThreadAtom).addingCommentToMessageId,
|
||||
(get, set, messageId: number | null) => {
|
||||
set(currentThreadAtom, { ...get(currentThreadAtom), addingCommentToMessageId: messageId });
|
||||
}
|
||||
);
|
||||
|
||||
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
||||
set(currentThreadAtom, initialState);
|
||||
});
|
||||
|
|
@ -13,6 +13,7 @@ import {
|
|||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { cleanupElectric } from "@/lib/electric/client";
|
||||
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
||||
|
||||
export function UserDropdown({
|
||||
|
|
@ -26,12 +27,20 @@ export function UserDropdown({
|
|||
}) {
|
||||
const router = useRouter();
|
||||
|
||||
const handleLogout = () => {
|
||||
const handleLogout = async () => {
|
||||
try {
|
||||
// Track logout event and reset PostHog identity
|
||||
trackLogout();
|
||||
resetUser();
|
||||
|
||||
// Best-effort cleanup of Electric SQL / PGlite
|
||||
// Even if this fails, login-time cleanup will handle it
|
||||
try {
|
||||
await cleanupElectric();
|
||||
} catch (err) {
|
||||
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
|
||||
}
|
||||
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.removeItem("surfsense_bearer_token");
|
||||
window.location.href = "/";
|
||||
|
|
@ -40,7 +49,7 @@ export function UserDropdown({
|
|||
console.error("Error during logout:", error);
|
||||
// Optionally, provide user feedback
|
||||
if (typeof window !== "undefined") {
|
||||
alert("Logout failed. Please try again.");
|
||||
localStorage.removeItem("surfsense_bearer_token");
|
||||
window.location.href = "/";
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,9 +5,15 @@ import {
|
|||
MessagePrimitive,
|
||||
useAssistantState,
|
||||
} from "@assistant-ui/react";
|
||||
import { CheckIcon, CopyIcon, DownloadIcon, RefreshCwIcon } from "lucide-react";
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { CheckIcon, CopyIcon, DownloadIcon, MessageSquare, RefreshCwIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useContext } from "react";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
addingCommentToMessageIdAtom,
|
||||
commentsEnabledAtom,
|
||||
} from "@/atoms/chat/current-thread.atom";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { BranchPicker } from "@/components/assistant-ui/branch-picker";
|
||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||
import {
|
||||
|
|
@ -16,6 +22,12 @@ import {
|
|||
} from "@/components/assistant-ui/thinking-steps";
|
||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container";
|
||||
import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet";
|
||||
import { CommentTrigger } from "@/components/chat-comments/comment-trigger/comment-trigger";
|
||||
import { useComments } from "@/hooks/use-comments";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const MessageError: FC = () => {
|
||||
return (
|
||||
|
|
@ -76,13 +88,142 @@ const AssistantMessageInner: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
function parseMessageId(assistantUiMessageId: string | undefined): number | null {
|
||||
if (!assistantUiMessageId) return null;
|
||||
const match = assistantUiMessageId.match(/^msg-(\d+)$/);
|
||||
return match ? Number.parseInt(match[1], 10) : null;
|
||||
}
|
||||
|
||||
export const AssistantMessage: FC = () => {
|
||||
const [messageHeight, setMessageHeight] = useState<number | undefined>(undefined);
|
||||
const [isSheetOpen, setIsSheetOpen] = useState(false);
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
const messageId = useAssistantState(({ message }) => message?.id);
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const dbMessageId = parseMessageId(messageId);
|
||||
const commentsEnabled = useAtomValue(commentsEnabledAtom);
|
||||
const [addingCommentToMessageId, setAddingCommentToMessageId] = useAtom(
|
||||
addingCommentToMessageIdAtom
|
||||
);
|
||||
|
||||
// Screen size detection for responsive comment UI
|
||||
// Mobile: < 768px (bottom sheet), Medium: 768px - 1024px (right sheet), Desktop: >= 1024px (inline panel)
|
||||
const isMediumScreen = useMediaQuery("(min-width: 768px) and (max-width: 1023px)");
|
||||
const isDesktop = useMediaQuery("(min-width: 1024px)");
|
||||
|
||||
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
|
||||
const isLastMessage = useAssistantState(({ message }) => message?.isLast ?? false);
|
||||
const isMessageStreaming = isThreadRunning && isLastMessage;
|
||||
|
||||
const { data: commentsData } = useComments({
|
||||
messageId: dbMessageId ?? 0,
|
||||
enabled: !!dbMessageId,
|
||||
});
|
||||
|
||||
const commentCount = commentsData?.total_count ?? 0;
|
||||
const hasComments = commentCount > 0;
|
||||
const isAddingComment = dbMessageId !== null && addingCommentToMessageId === dbMessageId;
|
||||
const showCommentPanel = hasComments || isAddingComment;
|
||||
|
||||
const handleToggleAddComment = () => {
|
||||
if (!dbMessageId) return;
|
||||
setAddingCommentToMessageId(isAddingComment ? null : dbMessageId);
|
||||
};
|
||||
|
||||
const handleCommentTriggerClick = () => {
|
||||
setIsSheetOpen(true);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!messageRef.current) return;
|
||||
const el = messageRef.current;
|
||||
const update = () => setMessageHeight(el.offsetHeight);
|
||||
update();
|
||||
const observer = new ResizeObserver(update);
|
||||
observer.observe(el);
|
||||
return () => observer.disconnect();
|
||||
}, []);
|
||||
|
||||
const showCommentTrigger = searchSpaceId && commentsEnabled && !isMessageStreaming && dbMessageId;
|
||||
|
||||
// Determine sheet side based on screen size
|
||||
const sheetSide = isMediumScreen ? "right" : "bottom";
|
||||
|
||||
return (
|
||||
<MessagePrimitive.Root
|
||||
className="aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||
ref={messageRef}
|
||||
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||
data-role="assistant"
|
||||
>
|
||||
<AssistantMessageInner />
|
||||
|
||||
{/* Desktop comment panel - only on lg screens and above */}
|
||||
{searchSpaceId && commentsEnabled && !isMessageStreaming && (
|
||||
<div className="absolute left-full top-0 ml-4 hidden lg:block w-72">
|
||||
<div
|
||||
className={`sticky top-3 ${showCommentPanel ? "opacity-100" : "opacity-0 group-hover:opacity-100"} transition-opacity`}
|
||||
>
|
||||
{!hasComments && (
|
||||
<CommentTrigger
|
||||
commentCount={0}
|
||||
isOpen={isAddingComment}
|
||||
onClick={handleToggleAddComment}
|
||||
disabled={!dbMessageId}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showCommentPanel && dbMessageId && (
|
||||
<div
|
||||
className={
|
||||
hasComments ? "" : "mt-2 animate-in fade-in slide-in-from-top-2 duration-200"
|
||||
}
|
||||
>
|
||||
<CommentPanelContainer
|
||||
messageId={dbMessageId}
|
||||
isOpen={true}
|
||||
maxHeight={messageHeight}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Mobile & Medium screen comment trigger - shown below lg breakpoint */}
|
||||
{showCommentTrigger && !isDesktop && (
|
||||
<div className="mt-2 flex justify-start">
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleCommentTriggerClick}
|
||||
className={cn(
|
||||
"flex items-center gap-2 rounded-full px-3 py-1.5 text-sm transition-colors",
|
||||
hasComments
|
||||
? "border border-primary/50 bg-primary/5 text-primary hover:bg-primary/10"
|
||||
: "text-muted-foreground hover:bg-muted hover:text-foreground"
|
||||
)}
|
||||
>
|
||||
<MessageSquare className={cn("size-4", hasComments && "fill-current")} />
|
||||
{hasComments ? (
|
||||
<span>
|
||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||
</span>
|
||||
) : (
|
||||
<span>Add comment</span>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Comment sheet - bottom for mobile, right for medium screens */}
|
||||
{showCommentTrigger && !isDesktop && (
|
||||
<CommentSheet
|
||||
messageId={dbMessageId}
|
||||
isOpen={isSheetOpen}
|
||||
onOpenChange={setIsSheetOpen}
|
||||
commentCount={commentCount}
|
||||
side={sheetSide}
|
||||
/>
|
||||
)}
|
||||
</MessagePrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -357,7 +357,7 @@ export const ComposerAddAttachment: FC = () => {
|
|||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleFileUpload} className="cursor-pointer">
|
||||
<Upload className="size-4" />
|
||||
<span>Upload Files</span>
|
||||
<span>Upload Documents</span>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
|
|
|
|||
|
|
@ -1,300 +0,0 @@
|
|||
import { AssistantIf, ComposerPrimitive, useAssistantState } from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import {
|
||||
AlertCircle,
|
||||
ArrowUpIcon,
|
||||
ChevronRightIcon,
|
||||
Loader2,
|
||||
Plug2,
|
||||
Plus,
|
||||
SquareIcon,
|
||||
} from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useCallback, useMemo, useRef, useState } from "react";
|
||||
import { getDocumentTypeLabel } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
llmPreferencesAtom,
|
||||
newLLMConfigsAtom,
|
||||
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { ComposerAddAttachment } from "@/components/assistant-ui/attachment";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { useSearchSourceConnectors } from "@/hooks/use-search-source-connectors";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const ConnectorIndicator: FC = () => {
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const { connectors, isLoading: connectorsLoading } = useSearchSourceConnectors(
|
||||
false,
|
||||
searchSpaceId ? Number(searchSpaceId) : undefined
|
||||
);
|
||||
const { data: documentTypeCounts, isLoading: documentTypesLoading } =
|
||||
useAtomValue(documentTypeCountsAtom);
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const closeTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const isLoading = connectorsLoading || documentTypesLoading;
|
||||
|
||||
const activeDocumentTypes = documentTypeCounts
|
||||
? Object.entries(documentTypeCounts).filter(([_, count]) => count > 0)
|
||||
: [];
|
||||
|
||||
// Count only active connectors (matching what's shown in the Active tab)
|
||||
const activeConnectorsCount = connectors.length;
|
||||
const hasConnectors = activeConnectorsCount > 0;
|
||||
const hasSources = hasConnectors || activeDocumentTypes.length > 0;
|
||||
|
||||
const handleMouseEnter = useCallback(() => {
|
||||
// Clear any pending close timeout
|
||||
if (closeTimeoutRef.current) {
|
||||
clearTimeout(closeTimeoutRef.current);
|
||||
closeTimeoutRef.current = null;
|
||||
}
|
||||
setIsOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleMouseLeave = useCallback(() => {
|
||||
// Delay closing by 150ms for better UX
|
||||
closeTimeoutRef.current = setTimeout(() => {
|
||||
setIsOpen(false);
|
||||
}, 150);
|
||||
}, []);
|
||||
|
||||
if (!searchSpaceId) return null;
|
||||
|
||||
return (
|
||||
<Popover open={isOpen} onOpenChange={setIsOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
"size-[34px] rounded-full p-1 flex items-center justify-center transition-colors relative",
|
||||
"hover:bg-muted-foreground/15 dark:hover:bg-muted-foreground/30",
|
||||
"outline-none focus:outline-none focus-visible:outline-none",
|
||||
"border-0 ring-0 focus:ring-0 shadow-none focus:shadow-none",
|
||||
"data-[state=open]:bg-transparent data-[state=open]:shadow-none data-[state=open]:ring-0",
|
||||
"text-muted-foreground"
|
||||
)}
|
||||
aria-label={
|
||||
hasConnectors
|
||||
? `View ${activeConnectorsCount} active connectors`
|
||||
: "Add your first connector"
|
||||
}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
>
|
||||
{isLoading ? (
|
||||
<Loader2 className="size-4 animate-spin" />
|
||||
) : (
|
||||
<>
|
||||
<Plug2 className="size-4" />
|
||||
{activeConnectorsCount > 0 && (
|
||||
<span className="absolute -top-0.5 -right-0.5 flex items-center justify-center min-w-[16px] h-4 px-1 text-[10px] font-medium rounded-full bg-primary text-primary-foreground shadow-sm">
|
||||
{activeConnectorsCount > 99 ? "99+" : activeConnectorsCount}
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
side="bottom"
|
||||
align="start"
|
||||
className="w-64 p-3"
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
>
|
||||
{hasSources ? (
|
||||
<div className="space-y-3">
|
||||
{activeConnectorsCount > 0 && (
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-xs font-medium text-muted-foreground">Active Connectors</p>
|
||||
<span className="text-xs font-medium bg-muted px-1.5 py-0.5 rounded">
|
||||
{activeConnectorsCount}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{activeConnectorsCount > 0 && (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{connectors.map((connector) => (
|
||||
<div
|
||||
key={`connector-${connector.id}`}
|
||||
className="flex items-center gap-1.5 rounded-md bg-muted/80 px-2.5 py-1.5 text-xs border border-border/50"
|
||||
>
|
||||
{getConnectorIcon(connector.connector_type, "size-3.5")}
|
||||
<span className="truncate max-w-[100px]">{connector.name}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{activeDocumentTypes.length > 0 && (
|
||||
<>
|
||||
{activeConnectorsCount > 0 && (
|
||||
<div className="pt-2 border-t border-border/50">
|
||||
<p className="text-xs font-medium text-muted-foreground mb-2">Documents</p>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{activeDocumentTypes.map(([docType, count]) => (
|
||||
<div
|
||||
key={docType}
|
||||
className="flex items-center gap-1.5 rounded-md bg-muted/80 px-2.5 py-1.5 text-xs border border-border/50"
|
||||
>
|
||||
{getConnectorIcon(docType, "size-3.5")}
|
||||
<span className="truncate max-w-[100px]">
|
||||
{getDocumentTypeLabel(docType)}
|
||||
</span>
|
||||
<span className="flex items-center justify-center min-w-[18px] h-[18px] px-1 text-[10px] font-medium rounded-full bg-primary/10 text-primary">
|
||||
{count > 999 ? "999+" : count}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className="pt-1 border-t border-border/50">
|
||||
<button
|
||||
type="button"
|
||||
className="inline-flex items-center gap-1.5 text-xs text-muted-foreground hover:text-foreground transition-colors"
|
||||
onClick={() => {
|
||||
/* Connector popup should be opened via the connector indicator button */
|
||||
}}
|
||||
>
|
||||
<Plus className="size-3" />
|
||||
Add more sources
|
||||
<ChevronRightIcon className="size-3" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
<p className="text-sm font-medium">No sources yet</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Add documents or connect data sources to enhance search results.
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
className="inline-flex items-center gap-1.5 rounded-md bg-primary px-3 py-1.5 text-xs font-medium text-primary-foreground hover:bg-primary/90 transition-colors mt-1"
|
||||
onClick={() => {
|
||||
/* Connector popup should be opened via the connector indicator button */
|
||||
}}
|
||||
>
|
||||
<Plus className="size-3" />
|
||||
Add Connector
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export const ComposerAction: FC = () => {
|
||||
// Check if any attachments are still being processed (running AND progress < 100)
|
||||
// When progress is 100, processing is done but waiting for send()
|
||||
const hasProcessingAttachments = useAssistantState(({ composer }) =>
|
||||
composer.attachments?.some((att) => {
|
||||
const status = att.status;
|
||||
if (status?.type !== "running") return false;
|
||||
const progress = (status as { type: "running"; progress?: number }).progress;
|
||||
return progress === undefined || progress < 100;
|
||||
})
|
||||
);
|
||||
|
||||
// Check if composer text is empty
|
||||
const isComposerEmpty = useAssistantState(({ composer }) => {
|
||||
const text = composer.text?.trim() || "";
|
||||
return text.length === 0;
|
||||
});
|
||||
|
||||
// Check if a model is configured
|
||||
const { data: userConfigs } = useAtomValue(newLLMConfigsAtom);
|
||||
const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom);
|
||||
const { data: preferences } = useAtomValue(llmPreferencesAtom);
|
||||
|
||||
const hasModelConfigured = useMemo(() => {
|
||||
if (!preferences) return false;
|
||||
const agentLlmId = preferences.agent_llm_id;
|
||||
if (agentLlmId === null || agentLlmId === undefined) return false;
|
||||
|
||||
// Check if the configured model actually exists
|
||||
if (agentLlmId < 0) {
|
||||
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
}
|
||||
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
}, [preferences, globalConfigs, userConfigs]);
|
||||
|
||||
const isSendDisabled = hasProcessingAttachments || isComposerEmpty || !hasModelConfigured;
|
||||
|
||||
return (
|
||||
<div className="aui-composer-action-wrapper relative mx-2 mb-2 flex items-center justify-between">
|
||||
<div className="flex items-center gap-1">
|
||||
<ComposerAddAttachment />
|
||||
<ConnectorIndicator />
|
||||
</div>
|
||||
|
||||
{/* Show processing indicator when attachments are being processed */}
|
||||
{hasProcessingAttachments && (
|
||||
<div className="flex items-center gap-1.5 text-muted-foreground text-xs">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
<span>Processing...</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show warning when no model is configured */}
|
||||
{!hasModelConfigured && !hasProcessingAttachments && (
|
||||
<div className="flex items-center gap-1.5 text-amber-600 dark:text-amber-400 text-xs">
|
||||
<AlertCircle className="size-3" />
|
||||
<span>Select a model</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AssistantIf condition={({ thread }) => !thread.isRunning}>
|
||||
<ComposerPrimitive.Send asChild disabled={isSendDisabled}>
|
||||
<TooltipIconButton
|
||||
tooltip={
|
||||
!hasModelConfigured
|
||||
? "Please select a model from the header to start chatting"
|
||||
: hasProcessingAttachments
|
||||
? "Wait for attachments to process"
|
||||
: isComposerEmpty
|
||||
? "Enter a message to send"
|
||||
: "Send message"
|
||||
}
|
||||
side="bottom"
|
||||
type="submit"
|
||||
variant="default"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"aui-composer-send size-8 rounded-full",
|
||||
isSendDisabled && "cursor-not-allowed opacity-50"
|
||||
)}
|
||||
aria-label="Send message"
|
||||
disabled={isSendDisabled}
|
||||
>
|
||||
<ArrowUpIcon className="aui-composer-send-icon size-4" />
|
||||
</TooltipIconButton>
|
||||
</ComposerPrimitive.Send>
|
||||
</AssistantIf>
|
||||
|
||||
<AssistantIf condition={({ thread }) => thread.isRunning}>
|
||||
<ComposerPrimitive.Cancel asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant="default"
|
||||
size="icon"
|
||||
className="aui-composer-cancel size-8 rounded-full"
|
||||
aria-label="Stop generating"
|
||||
>
|
||||
<SquareIcon className="aui-composer-cancel-icon size-3 fill-current" />
|
||||
</Button>
|
||||
</ComposerPrimitive.Cancel>
|
||||
</AssistantIf>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
import { ComposerPrimitive, useAssistantState, useComposerRuntime } from "@assistant-ui/react";
|
||||
import { useAtom, useSetAtom } from "jotai";
|
||||
import { useParams } from "next/navigation";
|
||||
import type { FC } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import {
|
||||
mentionedDocumentIdsAtom,
|
||||
mentionedDocumentsAtom,
|
||||
} from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { ComposerAddAttachment, ComposerAttachments } from "@/components/assistant-ui/attachment";
|
||||
import { ComposerAction } from "@/components/assistant-ui/composer-action";
|
||||
import {
|
||||
InlineMentionEditor,
|
||||
type InlineMentionEditorRef,
|
||||
} from "@/components/assistant-ui/inline-mention-editor";
|
||||
import {
|
||||
DocumentMentionPicker,
|
||||
type DocumentMentionPickerRef,
|
||||
} from "@/components/new-chat/document-mention-picker";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
|
||||
export const Composer: FC = () => {
|
||||
// ---- State for document mentions (using atoms to persist across remounts) ----
|
||||
const [mentionedDocuments, setMentionedDocuments] = useAtom(mentionedDocumentsAtom);
|
||||
const [showDocumentPopover, setShowDocumentPopover] = useState(false);
|
||||
const [mentionQuery, setMentionQuery] = useState("");
|
||||
const editorRef = useRef<InlineMentionEditorRef>(null);
|
||||
const editorContainerRef = useRef<HTMLDivElement>(null);
|
||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||
const { search_space_id } = useParams();
|
||||
const setMentionedDocumentIds = useSetAtom(mentionedDocumentIdsAtom);
|
||||
const composerRuntime = useComposerRuntime();
|
||||
const hasAutoFocusedRef = useRef(false);
|
||||
|
||||
// Check if thread is empty (new chat)
|
||||
const isThreadEmpty = useAssistantState(({ thread }) => thread.isEmpty);
|
||||
|
||||
// Check if thread is currently running (streaming response)
|
||||
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
|
||||
|
||||
// Auto-focus editor when on new chat page
|
||||
useEffect(() => {
|
||||
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {
|
||||
// Small delay to ensure the editor is fully mounted
|
||||
const timeoutId = setTimeout(() => {
|
||||
editorRef.current?.focus();
|
||||
hasAutoFocusedRef.current = true;
|
||||
}, 100);
|
||||
return () => clearTimeout(timeoutId);
|
||||
}
|
||||
}, [isThreadEmpty]);
|
||||
|
||||
// Sync mentioned document IDs to atom for use in chat request
|
||||
useEffect(() => {
|
||||
setMentionedDocumentIds({
|
||||
surfsense_doc_ids: mentionedDocuments
|
||||
.filter((doc) => doc.document_type === "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
document_ids: mentionedDocuments
|
||||
.filter((doc) => doc.document_type !== "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
});
|
||||
}, [mentionedDocuments, setMentionedDocumentIds]);
|
||||
|
||||
// Handle text change from inline editor - sync with assistant-ui composer
|
||||
const handleEditorChange = useCallback(
|
||||
(text: string) => {
|
||||
composerRuntime.setText(text);
|
||||
},
|
||||
[composerRuntime]
|
||||
);
|
||||
|
||||
// Handle @ mention trigger from inline editor
|
||||
const handleMentionTrigger = useCallback((query: string) => {
|
||||
setShowDocumentPopover(true);
|
||||
setMentionQuery(query);
|
||||
}, []);
|
||||
|
||||
// Handle mention close
|
||||
const handleMentionClose = useCallback(() => {
|
||||
if (showDocumentPopover) {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}
|
||||
}, [showDocumentPopover]);
|
||||
|
||||
// Handle keyboard navigation when popover is open
|
||||
const handleKeyDown = useCallback(
|
||||
(e: React.KeyboardEvent) => {
|
||||
if (showDocumentPopover) {
|
||||
if (e.key === "ArrowDown") {
|
||||
e.preventDefault();
|
||||
documentPickerRef.current?.moveDown();
|
||||
return;
|
||||
}
|
||||
if (e.key === "ArrowUp") {
|
||||
e.preventDefault();
|
||||
documentPickerRef.current?.moveUp();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
documentPickerRef.current?.selectHighlighted();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Escape") {
|
||||
e.preventDefault();
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
return;
|
||||
}
|
||||
}
|
||||
},
|
||||
[showDocumentPopover]
|
||||
);
|
||||
|
||||
// Handle submit from inline editor (Enter key)
|
||||
const handleSubmit = useCallback(() => {
|
||||
// Prevent sending while a response is still streaming
|
||||
if (isThreadRunning) {
|
||||
return;
|
||||
}
|
||||
if (!showDocumentPopover) {
|
||||
composerRuntime.send();
|
||||
// Clear the editor after sending
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
setMentionedDocumentIds({
|
||||
surfsense_doc_ids: [],
|
||||
document_ids: [],
|
||||
});
|
||||
}
|
||||
}, [
|
||||
showDocumentPopover,
|
||||
isThreadRunning,
|
||||
composerRuntime,
|
||||
setMentionedDocuments,
|
||||
setMentionedDocumentIds,
|
||||
]);
|
||||
|
||||
const handleDocumentRemove = useCallback(
|
||||
(docId: number, docType?: string) => {
|
||||
setMentionedDocuments((prev) => {
|
||||
const updated = prev.filter((doc) => !(doc.id === docId && doc.document_type === docType));
|
||||
setMentionedDocumentIds({
|
||||
surfsense_doc_ids: updated
|
||||
.filter((doc) => doc.document_type === "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
document_ids: updated
|
||||
.filter((doc) => doc.document_type !== "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
});
|
||||
return updated;
|
||||
});
|
||||
},
|
||||
[setMentionedDocuments, setMentionedDocumentIds]
|
||||
);
|
||||
|
||||
const handleDocumentsMention = useCallback(
|
||||
(documents: Pick<Document, "id" | "title" | "document_type">[]) => {
|
||||
const existingKeys = new Set(mentionedDocuments.map((d) => `${d.document_type}:${d.id}`));
|
||||
const newDocs = documents.filter(
|
||||
(doc) => !existingKeys.has(`${doc.document_type}:${doc.id}`)
|
||||
);
|
||||
|
||||
for (const doc of newDocs) {
|
||||
editorRef.current?.insertDocumentChip(doc);
|
||||
}
|
||||
|
||||
setMentionedDocuments((prev) => {
|
||||
const existingKeySet = new Set(prev.map((d) => `${d.document_type}:${d.id}`));
|
||||
const uniqueNewDocs = documents.filter(
|
||||
(doc) => !existingKeySet.has(`${doc.document_type}:${doc.id}`)
|
||||
);
|
||||
const updated = [...prev, ...uniqueNewDocs];
|
||||
setMentionedDocumentIds({
|
||||
surfsense_doc_ids: updated
|
||||
.filter((doc) => doc.document_type === "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
document_ids: updated
|
||||
.filter((doc) => doc.document_type !== "SURFSENSE_DOCS")
|
||||
.map((doc) => doc.id),
|
||||
});
|
||||
return updated;
|
||||
});
|
||||
|
||||
setMentionQuery("");
|
||||
},
|
||||
[mentionedDocuments, setMentionedDocuments, setMentionedDocumentIds]
|
||||
);
|
||||
|
||||
return (
|
||||
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col">
|
||||
<ComposerPrimitive.AttachmentDropzone className="aui-composer-attachment-dropzone flex w-full flex-col rounded-2xl border-input bg-muted px-1 pt-2 outline-none transition-shadow data-[dragging=true]:border-ring data-[dragging=true]:border-dashed data-[dragging=true]:bg-accent/50">
|
||||
<ComposerAttachments />
|
||||
{/* -------- Inline Mention Editor -------- */}
|
||||
<div ref={editorContainerRef} className="aui-composer-input-wrapper px-3 pt-3 pb-6">
|
||||
<InlineMentionEditor
|
||||
ref={editorRef}
|
||||
placeholder="Ask SurfSense or @mention docs"
|
||||
onMentionTrigger={handleMentionTrigger}
|
||||
onMentionClose={handleMentionClose}
|
||||
onChange={handleEditorChange}
|
||||
onDocumentRemove={handleDocumentRemove}
|
||||
onSubmit={handleSubmit}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="min-h-[24px]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* -------- Document mention popover (rendered via portal) -------- */}
|
||||
{showDocumentPopover &&
|
||||
typeof document !== "undefined" &&
|
||||
createPortal(
|
||||
<>
|
||||
{/* Backdrop */}
|
||||
<button
|
||||
type="button"
|
||||
className="fixed inset-0 cursor-default"
|
||||
style={{ zIndex: 9998 }}
|
||||
onClick={() => setShowDocumentPopover(false)}
|
||||
aria-label="Close document picker"
|
||||
/>
|
||||
{/* Popover positioned above input */}
|
||||
<div
|
||||
className="fixed shadow-2xl rounded-lg border border-border overflow-hidden bg-popover"
|
||||
style={{
|
||||
zIndex: 9999,
|
||||
bottom: editorContainerRef.current
|
||||
? `${window.innerHeight - editorContainerRef.current.getBoundingClientRect().top + 8}px`
|
||||
: "200px",
|
||||
left: editorContainerRef.current
|
||||
? `${editorContainerRef.current.getBoundingClientRect().left}px`
|
||||
: "50%",
|
||||
}}
|
||||
>
|
||||
<DocumentMentionPicker
|
||||
ref={documentPickerRef}
|
||||
searchSpaceId={Number(search_space_id)}
|
||||
onSelectionChange={handleDocumentsMention}
|
||||
onDone={() => {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}}
|
||||
initialSelectedDocuments={mentionedDocuments}
|
||||
externalSearch={mentionQuery}
|
||||
/>
|
||||
</div>
|
||||
</>,
|
||||
document.body
|
||||
)}
|
||||
<ComposerAction />
|
||||
</ComposerPrimitive.AttachmentDropzone>
|
||||
</ComposerPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,19 +1,16 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Cable, Loader2 } from "lucide-react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { type FC, useEffect, useMemo } from "react";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import type { FC } from "react";
|
||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { Dialog, DialogContent } from "@/components/ui/dialog";
|
||||
import { Tabs, TabsContent } from "@/components/ui/tabs";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import { useLogsSummary } from "@/hooks/use-logs";
|
||||
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { useConnectorsElectric } from "@/hooks/use-connectors-electric";
|
||||
import { useDocumentsElectric } from "@/hooks/use-documents-electric";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
|
||||
import { ConnectorConnectView } from "./connector-popup/connector-configs/views/connector-connect-view";
|
||||
|
|
@ -21,6 +18,7 @@ import { ConnectorEditView } from "./connector-popup/connector-configs/views/con
|
|||
import { IndexingConfigurationView } from "./connector-popup/connector-configs/views/indexing-configuration-view";
|
||||
import { OAUTH_CONNECTORS } from "./connector-popup/constants/connector-constants";
|
||||
import { useConnectorDialog } from "./connector-popup/hooks/use-connector-dialog";
|
||||
import { useIndexingConnectors } from "./connector-popup/hooks/use-indexing-connectors";
|
||||
import { ActiveConnectorsTab } from "./connector-popup/tabs/active-connectors-tab";
|
||||
import { AllConnectorsTab } from "./connector-popup/tabs/all-connectors-tab";
|
||||
import { ConnectorAccountsListView } from "./connector-popup/views/connector-accounts-list-view";
|
||||
|
|
@ -29,18 +27,13 @@ import { YouTubeCrawlerView } from "./connector-popup/views/youtube-crawler-view
|
|||
export const ConnectorIndicator: FC = () => {
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const searchParams = useSearchParams();
|
||||
const { data: documentTypeCounts, isLoading: documentTypesLoading } =
|
||||
useAtomValue(documentTypeCountsAtom);
|
||||
|
||||
// Fetch document type counts using Electric SQL + PGlite for real-time updates
|
||||
const { documentTypeCounts, loading: documentTypesLoading } = useDocumentsElectric(searchSpaceId);
|
||||
|
||||
// Check if YouTube view is active
|
||||
const isYouTubeView = searchParams.get("view") === "youtube";
|
||||
|
||||
// Track active indexing tasks
|
||||
const { summary: logsSummary } = useLogsSummary(searchSpaceId ? Number(searchSpaceId) : 0, 24, {
|
||||
enablePolling: true,
|
||||
refetchInterval: 5000,
|
||||
});
|
||||
|
||||
// Use the custom hook for dialog state management
|
||||
const {
|
||||
isOpen,
|
||||
|
|
@ -63,6 +56,7 @@ export const ConnectorIndicator: FC = () => {
|
|||
frequencyMinutes,
|
||||
allConnectors,
|
||||
viewingAccountsType,
|
||||
viewingMCPList,
|
||||
setSearchQuery,
|
||||
setStartDate,
|
||||
setEndDate,
|
||||
|
|
@ -86,6 +80,8 @@ export const ConnectorIndicator: FC = () => {
|
|||
handleBackFromYouTube,
|
||||
handleViewAccountsList,
|
||||
handleBackFromAccountsList,
|
||||
handleBackFromMCPList,
|
||||
handleAddNewMCPFromList,
|
||||
handleQuickIndexConnector,
|
||||
connectorConfig,
|
||||
setConnectorConfig,
|
||||
|
|
@ -93,57 +89,36 @@ export const ConnectorIndicator: FC = () => {
|
|||
setConnectorName,
|
||||
} = useConnectorDialog();
|
||||
|
||||
// Fetch connectors using React Query with conditional refetchInterval
|
||||
// This automatically refetches when mutations invalidate the cache (event-driven)
|
||||
// and also polls when dialog is open to catch external changes
|
||||
// Fetch connectors using Electric SQL + PGlite for real-time updates
|
||||
// This provides instant updates when connectors change, without polling
|
||||
const {
|
||||
data: connectors = [],
|
||||
isLoading: connectorsLoading,
|
||||
refetch: refreshConnectors,
|
||||
} = useQuery({
|
||||
queryKey: cacheKeys.connectors.all(searchSpaceId || ""),
|
||||
queryFn: () =>
|
||||
connectorsApiService.getConnectors({
|
||||
queryParams: {
|
||||
search_space_id: searchSpaceId ? Number(searchSpaceId) : undefined,
|
||||
},
|
||||
}),
|
||||
enabled: !!searchSpaceId,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes (same as connectorsAtom)
|
||||
// Poll when dialog is open to catch external changes
|
||||
refetchInterval: isOpen ? 5000 : false, // 5 seconds when open, no polling when closed
|
||||
});
|
||||
connectors: connectorsFromElectric = [],
|
||||
loading: connectorsLoading,
|
||||
error: connectorsError,
|
||||
refreshConnectors: refreshConnectorsElectric,
|
||||
} = useConnectorsElectric(searchSpaceId);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
// Fallback to API if Electric is not available or fails
|
||||
// Use Electric data if: 1) we have data, or 2) still loading without error
|
||||
// Use API data if: Electric failed (has error) or finished loading with no data
|
||||
const useElectricData =
|
||||
connectorsFromElectric.length > 0 || (connectorsLoading && !connectorsError);
|
||||
const connectors = useElectricData ? connectorsFromElectric : allConnectors || [];
|
||||
|
||||
// Also refresh document type counts when dialog is open
|
||||
useEffect(() => {
|
||||
if (!isOpen || !searchSpaceId) return;
|
||||
// Manual refresh function that works with both Electric and API
|
||||
const refreshConnectors = async () => {
|
||||
if (useElectricData) {
|
||||
await refreshConnectorsElectric();
|
||||
} else {
|
||||
// Fallback: use allConnectors from useConnectorDialog (which uses connectorsAtom)
|
||||
// The connectorsAtom will handle refetching if needed
|
||||
}
|
||||
};
|
||||
|
||||
const POLL_INTERVAL = 5000; // 5 seconds, same as connectors
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
// Invalidate document type counts to refresh active document types
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: cacheKeys.documents.typeCounts(searchSpaceId),
|
||||
});
|
||||
}, POLL_INTERVAL);
|
||||
|
||||
// Cleanup interval on unmount or when dialog closes
|
||||
return () => {
|
||||
clearInterval(intervalId);
|
||||
};
|
||||
}, [isOpen, searchSpaceId, queryClient]);
|
||||
|
||||
// Get connector IDs that are currently being indexed
|
||||
const indexingConnectorIds = useMemo(() => {
|
||||
if (!logsSummary?.active_tasks) return new Set<number>();
|
||||
return new Set(
|
||||
logsSummary.active_tasks
|
||||
.filter((task) => task.source?.includes("connector_indexing") && task.connector_id != null)
|
||||
.map((task) => task.connector_id as number)
|
||||
);
|
||||
}, [logsSummary?.active_tasks]);
|
||||
// Track indexing state locally - clears automatically when Electric SQL detects last_indexed_at changed
|
||||
const { indexingConnectorIds, startIndexing } = useIndexingConnectors(
|
||||
connectors as SearchSourceConnector[]
|
||||
);
|
||||
|
||||
const isLoading = connectorsLoading || documentTypesLoading;
|
||||
|
||||
|
|
@ -155,11 +130,13 @@ export const ConnectorIndicator: FC = () => {
|
|||
const hasConnectors = connectors.length > 0;
|
||||
const hasSources = hasConnectors || activeDocumentTypes.length > 0;
|
||||
const totalSourceCount = connectors.length + activeDocumentTypes.length;
|
||||
const activeConnectorsCount = connectors.length; // Only actual connectors, not document types
|
||||
|
||||
const activeConnectorsCount = connectors.length;
|
||||
|
||||
// Check which connectors are already connected
|
||||
// Using Electric SQL + PGlite for real-time connector updates
|
||||
const connectedTypes = new Set(
|
||||
(allConnectors || []).map((c: SearchSourceConnector) => c.connector_type)
|
||||
(connectors || []).map((c: SearchSourceConnector) => c.connector_type)
|
||||
);
|
||||
|
||||
if (!searchSpaceId) return null;
|
||||
|
|
@ -199,13 +176,23 @@ export const ConnectorIndicator: FC = () => {
|
|||
{/* YouTube Crawler View - shown when adding YouTube videos */}
|
||||
{isYouTubeView && searchSpaceId ? (
|
||||
<YouTubeCrawlerView searchSpaceId={searchSpaceId} onBack={handleBackFromYouTube} />
|
||||
) : viewingMCPList ? (
|
||||
<ConnectorAccountsListView
|
||||
connectorType="MCP_CONNECTOR"
|
||||
connectorTitle="MCP Connectors"
|
||||
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
onBack={handleBackFromMCPList}
|
||||
onManage={handleStartEdit}
|
||||
onAddAccount={handleAddNewMCPFromList}
|
||||
addButtonText="Add New MCP Server"
|
||||
/>
|
||||
) : viewingAccountsType ? (
|
||||
<ConnectorAccountsListView
|
||||
connectorType={viewingAccountsType.connectorType}
|
||||
connectorTitle={viewingAccountsType.connectorTitle}
|
||||
connectors={(allConnectors || []) as SearchSourceConnector[]}
|
||||
connectors={(connectors || []) as SearchSourceConnector[]} // Using Electric SQL + PGlite for real-time connector updates (all connector types)
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
logsSummary={logsSummary}
|
||||
onBack={handleBackFromAccountsList}
|
||||
onManage={handleStartEdit}
|
||||
onAddAccount={() => {
|
||||
|
|
@ -221,7 +208,7 @@ export const ConnectorIndicator: FC = () => {
|
|||
) : connectingConnectorType ? (
|
||||
<ConnectorConnectView
|
||||
connectorType={connectingConnectorType}
|
||||
onSubmit={handleSubmitConnectForm}
|
||||
onSubmit={(formData) => handleSubmitConnectForm(formData, startIndexing)}
|
||||
onBack={handleBackFromConnect}
|
||||
isSubmitting={isCreatingConnector}
|
||||
/>
|
||||
|
|
@ -239,17 +226,23 @@ export const ConnectorIndicator: FC = () => {
|
|||
isSaving={isSaving}
|
||||
isDisconnecting={isDisconnecting}
|
||||
isIndexing={indexingConnectorIds.has(editingConnector.id)}
|
||||
searchSpaceId={searchSpaceId?.toString()}
|
||||
onStartDateChange={setStartDate}
|
||||
onEndDateChange={setEndDate}
|
||||
onPeriodicEnabledChange={setPeriodicEnabled}
|
||||
onFrequencyChange={setFrequencyMinutes}
|
||||
onSave={() => handleSaveConnector(() => refreshConnectors())}
|
||||
onSave={() => {
|
||||
startIndexing(editingConnector.id);
|
||||
handleSaveConnector(() => refreshConnectors());
|
||||
}}
|
||||
onDisconnect={() => handleDisconnectConnector(() => refreshConnectors())}
|
||||
onBack={handleBackFromEdit}
|
||||
onQuickIndex={
|
||||
editingConnector.connector_type !== "GOOGLE_DRIVE_CONNECTOR"
|
||||
? () =>
|
||||
handleQuickIndexConnector(editingConnector.id, editingConnector.connector_type)
|
||||
? () => {
|
||||
startIndexing(editingConnector.id);
|
||||
handleQuickIndexConnector(editingConnector.id, editingConnector.connector_type);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onConfigChange={setConnectorConfig}
|
||||
|
|
@ -276,7 +269,12 @@ export const ConnectorIndicator: FC = () => {
|
|||
onPeriodicEnabledChange={setPeriodicEnabled}
|
||||
onFrequencyChange={setFrequencyMinutes}
|
||||
onConfigChange={setIndexingConnectorConfig}
|
||||
onStartIndexing={() => handleStartIndexing(() => refreshConnectors())}
|
||||
onStartIndexing={() => {
|
||||
if (indexingConfig.connectorId) {
|
||||
startIndexing(indexingConfig.connectorId);
|
||||
}
|
||||
handleStartIndexing(() => refreshConnectors());
|
||||
}}
|
||||
onSkip={handleSkipIndexing}
|
||||
/>
|
||||
) : (
|
||||
|
|
@ -305,10 +303,9 @@ export const ConnectorIndicator: FC = () => {
|
|||
searchSpaceId={searchSpaceId}
|
||||
connectedTypes={connectedTypes}
|
||||
connectingId={connectingId}
|
||||
allConnectors={allConnectors}
|
||||
allConnectors={connectors}
|
||||
documentTypeCounts={documentTypeCounts}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
logsSummary={logsSummary}
|
||||
onConnectOAuth={handleConnectOAuth}
|
||||
onConnectNonOAuth={handleConnectNonOAuth}
|
||||
onCreateWebcrawler={handleCreateWebcrawler}
|
||||
|
|
@ -325,7 +322,6 @@ export const ConnectorIndicator: FC = () => {
|
|||
activeDocumentTypes={activeDocumentTypes}
|
||||
connectors={connectors as SearchSourceConnector[]}
|
||||
indexingConnectorIds={indexingConnectorIds}
|
||||
logsSummary={logsSummary}
|
||||
searchSpaceId={searchSpaceId}
|
||||
onTabChange={handleTabChange}
|
||||
onManage={handleStartEdit}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"use client";
|
||||
|
||||
import { IconBrandYoutube } from "@tabler/icons-react";
|
||||
import { differenceInDays, differenceInMinutes, format, isToday, isYesterday } from "date-fns";
|
||||
import { FileText, Loader2 } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { LogActiveTask } from "@/contracts/types/log.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||
import { ConnectorStatusBadge } from "./connector-status-badge";
|
||||
|
|
@ -20,24 +19,12 @@ interface ConnectorCardProps {
|
|||
isConnecting?: boolean;
|
||||
documentCount?: number;
|
||||
accountCount?: number;
|
||||
lastIndexedAt?: string | null;
|
||||
connectorCount?: number;
|
||||
isIndexing?: boolean;
|
||||
activeTask?: LogActiveTask;
|
||||
onConnect?: () => void;
|
||||
onManage?: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a number from the active task message for display
|
||||
* Looks for patterns like "45 indexed", "Processing 123", etc.
|
||||
*/
|
||||
function extractIndexedCount(message: string | undefined): number | null {
|
||||
if (!message) return null;
|
||||
// Try to find a number in the message
|
||||
const match = message.match(/(\d+)/);
|
||||
return match ? parseInt(match[1], 10) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format document count (e.g., "1.2k docs", "500 docs", "1.5M docs")
|
||||
*/
|
||||
|
|
@ -52,45 +39,6 @@ function formatDocumentCount(count: number | undefined): string {
|
|||
return `${m.replace(/\.0$/, "")}M docs`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format last indexed date with contextual messages
|
||||
* Examples: "Just now", "10 minutes ago", "Today at 2:30 PM", "Yesterday at 3:45 PM", "3 days ago", "Jan 15, 2026"
|
||||
*/
|
||||
function formatLastIndexedDate(dateString: string): string {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const minutesAgo = differenceInMinutes(now, date);
|
||||
const daysAgo = differenceInDays(now, date);
|
||||
|
||||
// Just now (within last minute)
|
||||
if (minutesAgo < 1) {
|
||||
return "Just now";
|
||||
}
|
||||
|
||||
// X minutes ago (less than 1 hour)
|
||||
if (minutesAgo < 60) {
|
||||
return `${minutesAgo} ${minutesAgo === 1 ? "minute" : "minutes"} ago`;
|
||||
}
|
||||
|
||||
// Today at [time]
|
||||
if (isToday(date)) {
|
||||
return `Today at ${format(date, "h:mm a")}`;
|
||||
}
|
||||
|
||||
// Yesterday at [time]
|
||||
if (isYesterday(date)) {
|
||||
return `Yesterday at ${format(date, "h:mm a")}`;
|
||||
}
|
||||
|
||||
// X days ago (less than 7 days)
|
||||
if (daysAgo < 7) {
|
||||
return `${daysAgo} ${daysAgo === 1 ? "day" : "days"} ago`;
|
||||
}
|
||||
|
||||
// Full date for older entries
|
||||
return format(date, "MMM d, yyyy");
|
||||
}
|
||||
|
||||
export const ConnectorCard: FC<ConnectorCardProps> = ({
|
||||
id,
|
||||
title,
|
||||
|
|
@ -100,12 +48,12 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
|
|||
isConnecting = false,
|
||||
documentCount,
|
||||
accountCount,
|
||||
lastIndexedAt,
|
||||
connectorCount,
|
||||
isIndexing = false,
|
||||
activeTask,
|
||||
onConnect,
|
||||
onManage,
|
||||
}) => {
|
||||
const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR;
|
||||
// Get connector status
|
||||
const { getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings } =
|
||||
useConnectorStatus();
|
||||
|
|
@ -115,36 +63,11 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
|
|||
const statusMessage = getConnectorStatusMessage(connectorType);
|
||||
const showWarnings = shouldShowWarnings();
|
||||
|
||||
// Extract count from active task message during indexing
|
||||
const indexingCount = extractIndexedCount(activeTask?.message);
|
||||
|
||||
// Determine the status content to display
|
||||
const getStatusContent = () => {
|
||||
if (isIndexing) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 w-full max-w-[200px]">
|
||||
<span className="text-[11px] text-primary font-medium whitespace-nowrap">
|
||||
{indexingCount !== null ? <>{indexingCount.toLocaleString()} indexed</> : "Syncing..."}
|
||||
</span>
|
||||
{/* Indeterminate progress bar with animation */}
|
||||
<div className="relative flex-1 h-1 overflow-hidden rounded-full bg-primary/20">
|
||||
<div className="absolute h-full bg-primary rounded-full animate-progress-indeterminate" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isConnected) {
|
||||
// Show last indexed date for connected connectors
|
||||
if (lastIndexedAt) {
|
||||
return (
|
||||
<span className="whitespace-nowrap text-[10px]">
|
||||
Last indexed: {formatLastIndexedDate(lastIndexedAt)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
// Fallback for connected but never indexed
|
||||
return <span className="whitespace-nowrap text-[10px]">Never indexed</span>;
|
||||
// Don't show last indexed in overview tabs - only show in accounts list view
|
||||
return null;
|
||||
}
|
||||
|
||||
return description;
|
||||
|
|
@ -186,19 +109,33 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
|
|||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-[10px] text-muted-foreground mt-1">{getStatusContent()}</div>
|
||||
{isConnected && documentCount !== undefined && (
|
||||
<p className="text-[10px] text-muted-foreground mt-0.5 flex items-center gap-1.5">
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
{accountCount !== undefined && accountCount > 0 && (
|
||||
{isIndexing ? (
|
||||
<p className="text-[11px] text-primary mt-1 flex items-center gap-1.5">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Syncing
|
||||
</p>
|
||||
) : isConnected ? (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 flex items-center gap-1.5">
|
||||
{isMCP && connectorCount !== undefined ? (
|
||||
<span>
|
||||
{connectorCount} {connectorCount === 1 ? "server" : "servers"}
|
||||
</span>
|
||||
) : (
|
||||
<>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
{accountCount !== undefined && accountCount > 0 && (
|
||||
<>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</p>
|
||||
) : (
|
||||
<div className="text-[10px] text-muted-foreground mt-1">{getStatusContent()}</div>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ export const ConnectorStatusBadge: FC<ConnectorStatusBadgeProps> = ({
|
|||
case "deprecated":
|
||||
return {
|
||||
icon: AlertTriangle,
|
||||
className: "ext-slate-500 dark:text-slate-400",
|
||||
className: "text-slate-500 dark:text-slate-400",
|
||||
defaultTitle: "Deprecated",
|
||||
};
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
|
|
@ -16,6 +17,8 @@ interface PeriodicSyncConfigProps {
|
|||
frequencyMinutes: string;
|
||||
onEnabledChange: (enabled: boolean) => void;
|
||||
onFrequencyChange: (frequency: string) => void;
|
||||
disabled?: boolean;
|
||||
disabledMessage?: string;
|
||||
}
|
||||
|
||||
export const PeriodicSyncConfig: FC<PeriodicSyncConfigProps> = ({
|
||||
|
|
@ -23,6 +26,8 @@ export const PeriodicSyncConfig: FC<PeriodicSyncConfigProps> = ({
|
|||
frequencyMinutes,
|
||||
onEnabledChange,
|
||||
onFrequencyChange,
|
||||
disabled = false,
|
||||
disabledMessage,
|
||||
}) => {
|
||||
return (
|
||||
<div className="rounded-xl bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6">
|
||||
|
|
@ -33,9 +38,17 @@ export const PeriodicSyncConfig: FC<PeriodicSyncConfigProps> = ({
|
|||
Automatically re-index at regular intervals
|
||||
</p>
|
||||
</div>
|
||||
<Switch checked={enabled} onCheckedChange={onEnabledChange} />
|
||||
<Switch checked={enabled} onCheckedChange={onEnabledChange} disabled={disabled} />
|
||||
</div>
|
||||
|
||||
{/* Show disabled message when periodic sync can't be enabled */}
|
||||
{disabled && disabledMessage && (
|
||||
<div className="mt-3 flex items-start gap-2 text-amber-600 dark:text-amber-400">
|
||||
<AlertCircle className="size-4 mt-0.5 shrink-0" />
|
||||
<p className="text-xs sm:text-sm">{disabledMessage}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{enabled && (
|
||||
<div className="mt-4 pt-4 border-t border-slate-400/20 space-y-3">
|
||||
<div className="space-y-2">
|
||||
|
|
|
|||
|
|
@ -0,0 +1,288 @@
|
|||
"use client";
|
||||
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
||||
import { type FC, useRef, useState } from "react";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import {
|
||||
extractServerName,
|
||||
type MCPConnectionTestResult,
|
||||
parseMCPConfig,
|
||||
testMCPConnection,
|
||||
} from "../../utils/mcp-config-validator";
|
||||
import type { ConnectFormProps } from "..";
|
||||
|
||||
export const MCPConnectForm: FC<ConnectFormProps> = ({ onSubmit, isSubmitting }) => {
|
||||
const isSubmittingRef = useRef(false);
|
||||
const [configJson, setConfigJson] = useState("");
|
||||
const [jsonError, setJsonError] = useState<string | null>(null);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [showDetails, setShowDetails] = useState(false);
|
||||
const [testResult, setTestResult] = useState<MCPConnectionTestResult | null>(null);
|
||||
|
||||
// Default config for stdio transport (local process)
|
||||
const DEFAULT_STDIO_CONFIG = JSON.stringify(
|
||||
{
|
||||
name: "My MCP Server",
|
||||
command: "npx",
|
||||
args: ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/directory"],
|
||||
env: {
|
||||
API_KEY: "your_api_key_here",
|
||||
},
|
||||
transport: "stdio",
|
||||
},
|
||||
null,
|
||||
2
|
||||
);
|
||||
|
||||
// Default config for HTTP transport (remote server)
|
||||
const DEFAULT_HTTP_CONFIG = JSON.stringify(
|
||||
{
|
||||
name: "My Remote MCP Server",
|
||||
url: "https://your-mcp-server.com/mcp",
|
||||
headers: {
|
||||
API_KEY: "your_api_key_here",
|
||||
},
|
||||
transport: "streamable-http",
|
||||
},
|
||||
null,
|
||||
2
|
||||
);
|
||||
|
||||
const DEFAULT_CONFIG = DEFAULT_STDIO_CONFIG;
|
||||
|
||||
const parseConfig = () => {
|
||||
const result = parseMCPConfig(configJson);
|
||||
if (result.error) {
|
||||
setJsonError(result.error);
|
||||
} else {
|
||||
setJsonError(null);
|
||||
}
|
||||
return result.config;
|
||||
};
|
||||
|
||||
const handleConfigChange = (value: string) => {
|
||||
setConfigJson(value);
|
||||
|
||||
// Clear previous error
|
||||
if (jsonError) {
|
||||
setJsonError(null);
|
||||
}
|
||||
|
||||
// Validate immediately to show errors as user types (with debouncing via parseMCPConfig cache)
|
||||
if (value.trim()) {
|
||||
const result = parseMCPConfig(value);
|
||||
if (result.error) {
|
||||
setJsonError(result.error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleTestConnection = async () => {
|
||||
const serverConfig = parseConfig();
|
||||
if (!serverConfig) {
|
||||
setTestResult({
|
||||
status: "error",
|
||||
message: jsonError || "Invalid configuration",
|
||||
tools: [],
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setIsTesting(true);
|
||||
setTestResult(null);
|
||||
|
||||
const result = await testMCPConnection(serverConfig);
|
||||
setTestResult(result);
|
||||
setIsTesting(false);
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
|
||||
// Prevent multiple submissions
|
||||
if (isSubmittingRef.current || isSubmitting) {
|
||||
return;
|
||||
}
|
||||
|
||||
const serverConfig = parseConfig();
|
||||
if (!serverConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract server name from config if provided
|
||||
const serverName = extractServerName(configJson);
|
||||
|
||||
isSubmittingRef.current = true;
|
||||
try {
|
||||
await onSubmit({
|
||||
name: serverName,
|
||||
connector_type: EnumConnectorName.MCP_CONNECTOR,
|
||||
config: { server_config: serverConfig },
|
||||
is_indexable: false,
|
||||
is_active: true,
|
||||
last_indexed_at: null,
|
||||
periodic_indexing_enabled: false,
|
||||
indexing_frequency_minutes: null,
|
||||
next_scheduled_at: null,
|
||||
});
|
||||
} finally {
|
||||
isSubmittingRef.current = false;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6 pb-6">
|
||||
<Alert className="bg-slate-400/5 dark:bg-white/5 border-slate-400/20 p-2 sm:p-3 [&>svg]:top-2 sm:[&>svg]:top-3">
|
||||
<Server className="h-4 w-4 shrink-0" />
|
||||
<AlertDescription className="text-[10px] sm:text-xs">
|
||||
Connect to an MCP (Model Context Protocol) server. Each MCP server is added as a separate
|
||||
connector.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<form id="mcp-connect-form" onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-4 sm:p-6 space-y-4">
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between flex-wrap gap-2">
|
||||
<Label htmlFor="config">MCP Server Configuration (JSON)</Label>
|
||||
{!configJson && (
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 text-xs text-muted-foreground hover:text-foreground"
|
||||
onClick={() => handleConfigChange(DEFAULT_STDIO_CONFIG)}
|
||||
>
|
||||
Local Example
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 text-xs text-muted-foreground hover:text-foreground"
|
||||
onClick={() => handleConfigChange(DEFAULT_HTTP_CONFIG)}
|
||||
>
|
||||
Remote Example
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Textarea
|
||||
id="config"
|
||||
value={configJson}
|
||||
onChange={(e) => handleConfigChange(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Tab") {
|
||||
e.preventDefault();
|
||||
const target = e.target as HTMLTextAreaElement;
|
||||
const start = target.selectionStart;
|
||||
const end = target.selectionEnd;
|
||||
const indent = " "; // 2 spaces for JSON
|
||||
const newValue =
|
||||
configJson.substring(0, start) + indent + configJson.substring(end);
|
||||
handleConfigChange(newValue);
|
||||
// Set cursor position after the inserted tab
|
||||
requestAnimationFrame(() => {
|
||||
target.selectionStart = target.selectionEnd = start + indent.length;
|
||||
});
|
||||
}
|
||||
}}
|
||||
placeholder={DEFAULT_CONFIG}
|
||||
rows={16}
|
||||
className={`font-mono text-xs ${jsonError ? "border-red-500" : ""}`}
|
||||
/>
|
||||
{jsonError && <p className="text-xs text-red-500">JSON Error: {jsonError}</p>}
|
||||
<p className="text-[10px] sm:text-xs text-muted-foreground">
|
||||
Paste a single MCP server configuration. Must include: name, command, args (optional),
|
||||
env (optional), transport (optional).
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Test Connection */}
|
||||
<div className="pt-4">
|
||||
<Button
|
||||
type="button"
|
||||
onClick={handleTestConnection}
|
||||
disabled={isTesting}
|
||||
variant="secondary"
|
||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||
>
|
||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Test Result */}
|
||||
{testResult && (
|
||||
<Alert
|
||||
className={
|
||||
testResult.status === "success"
|
||||
? "border-green-500/50 bg-green-500/10"
|
||||
: "border-red-500/50 bg-red-500/10"
|
||||
}
|
||||
>
|
||||
{testResult.status === "success" ? (
|
||||
<CheckCircle2 className="h-4 w-4 text-green-600" />
|
||||
) : (
|
||||
<XCircle className="h-4 w-4 text-red-600" />
|
||||
)}
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center justify-between">
|
||||
<AlertTitle className="text-sm">
|
||||
{testResult.status === "success"
|
||||
? "Connection Successful"
|
||||
: "Connection Failed"}
|
||||
</AlertTitle>
|
||||
{testResult.tools.length > 0 && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 self-start sm:self-auto text-xs"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setShowDetails(!showDetails);
|
||||
}}
|
||||
>
|
||||
{showDetails ? (
|
||||
<>
|
||||
<ChevronUp className="h-3 w-3 mr-1" />
|
||||
<span className="hidden sm:inline">Hide Details</span>
|
||||
<span className="sm:hidden">Hide</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ChevronDown className="h-3 w-3 mr-1" />
|
||||
<span className="hidden sm:inline">Show Details</span>
|
||||
<span className="sm:hidden">Show</span>
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<AlertDescription className="text-[10px] sm:text-xs mt-1">
|
||||
{testResult.message}
|
||||
{showDetails && testResult.tools.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-green-500/20">
|
||||
<p className="font-semibold mb-2">Available tools:</p>
|
||||
<ul className="list-disc list-inside text-xs space-y-0.5">
|
||||
{testResult.tools.map((tool, i) => (
|
||||
<li key={i}>{tool.name}</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
</AlertDescription>
|
||||
</div>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -6,6 +6,7 @@ import { ElasticsearchConnectForm } from "./components/elasticsearch-connect-for
|
|||
import { GithubConnectForm } from "./components/github-connect-form";
|
||||
import { LinkupApiConnectForm } from "./components/linkup-api-connect-form";
|
||||
import { LumaConnectForm } from "./components/luma-connect-form";
|
||||
import { MCPConnectForm } from "./components/mcp-connect-form";
|
||||
import { SearxngConnectForm } from "./components/searxng-connect-form";
|
||||
import { TavilyApiConnectForm } from "./components/tavily-api-connect-form";
|
||||
|
||||
|
|
@ -15,6 +16,7 @@ export interface ConnectFormProps {
|
|||
connector_type: string;
|
||||
config: Record<string, unknown>;
|
||||
is_indexable: boolean;
|
||||
is_active: boolean;
|
||||
last_indexed_at: null;
|
||||
periodic_indexing_enabled: boolean;
|
||||
indexing_frequency_minutes: number | null;
|
||||
|
|
@ -54,6 +56,8 @@ export function getConnectFormComponent(connectorType: string): ConnectFormCompo
|
|||
return LumaConnectForm;
|
||||
case "CIRCLEBACK_CONNECTOR":
|
||||
return CirclebackConnectForm;
|
||||
case "MCP_CONNECTOR":
|
||||
return MCPConnectForm;
|
||||
// Add other connector types here as needed
|
||||
default:
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
"use client";
|
||||
|
||||
import { Info } from "lucide-react";
|
||||
import { File, FileSpreadsheet, FileText, FolderClosed, Image, Presentation } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { GoogleDriveFolderTree } from "@/components/connectors/google-drive-folder-tree";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import type { ConnectorConfigProps } from "../index";
|
||||
|
||||
interface SelectedFolder {
|
||||
|
|
@ -13,128 +21,292 @@ interface SelectedFolder {
|
|||
name: string;
|
||||
}
|
||||
|
||||
interface IndexingOptions {
|
||||
max_files_per_folder: number;
|
||||
incremental_sync: boolean;
|
||||
include_subfolders: boolean;
|
||||
}
|
||||
|
||||
const DEFAULT_INDEXING_OPTIONS: IndexingOptions = {
|
||||
max_files_per_folder: 100,
|
||||
incremental_sync: true,
|
||||
include_subfolders: true,
|
||||
};
|
||||
|
||||
// Helper to get appropriate icon for file type based on file name
|
||||
function getFileIconFromName(fileName: string, className: string = "size-3.5 shrink-0") {
|
||||
const lowerName = fileName.toLowerCase();
|
||||
// Spreadsheets
|
||||
if (
|
||||
lowerName.endsWith(".xlsx") ||
|
||||
lowerName.endsWith(".xls") ||
|
||||
lowerName.endsWith(".csv") ||
|
||||
lowerName.includes("spreadsheet")
|
||||
) {
|
||||
return <FileSpreadsheet className={`${className} text-green-500`} />;
|
||||
}
|
||||
// Presentations
|
||||
if (
|
||||
lowerName.endsWith(".pptx") ||
|
||||
lowerName.endsWith(".ppt") ||
|
||||
lowerName.includes("presentation")
|
||||
) {
|
||||
return <Presentation className={`${className} text-orange-500`} />;
|
||||
}
|
||||
// Documents (word, text only - not PDF)
|
||||
if (
|
||||
lowerName.endsWith(".docx") ||
|
||||
lowerName.endsWith(".doc") ||
|
||||
lowerName.endsWith(".txt") ||
|
||||
lowerName.includes("document") ||
|
||||
lowerName.includes("word") ||
|
||||
lowerName.includes("text")
|
||||
) {
|
||||
return <FileText className={`${className} text-gray-500`} />;
|
||||
}
|
||||
// Images
|
||||
if (
|
||||
lowerName.endsWith(".png") ||
|
||||
lowerName.endsWith(".jpg") ||
|
||||
lowerName.endsWith(".jpeg") ||
|
||||
lowerName.endsWith(".gif") ||
|
||||
lowerName.endsWith(".webp") ||
|
||||
lowerName.endsWith(".svg")
|
||||
) {
|
||||
return <Image className={`${className} text-purple-500`} />;
|
||||
}
|
||||
// Default (including PDF)
|
||||
return <File className={`${className} text-gray-500`} />;
|
||||
}
|
||||
|
||||
export const GoogleDriveConfig: FC<ConnectorConfigProps> = ({ connector, onConfigChange }) => {
|
||||
// Initialize with existing selected folders and files from connector config
|
||||
const existingFolders =
|
||||
(connector.config?.selected_folders as SelectedFolder[] | undefined) || [];
|
||||
const existingFiles = (connector.config?.selected_files as SelectedFolder[] | undefined) || [];
|
||||
const existingIndexingOptions =
|
||||
(connector.config?.indexing_options as IndexingOptions | undefined) || DEFAULT_INDEXING_OPTIONS;
|
||||
|
||||
const [selectedFolders, setSelectedFolders] = useState<SelectedFolder[]>(existingFolders);
|
||||
const [selectedFiles, setSelectedFiles] = useState<SelectedFolder[]>(existingFiles);
|
||||
const [showFolderSelector, setShowFolderSelector] = useState(false);
|
||||
const [indexingOptions, setIndexingOptions] = useState<IndexingOptions>(existingIndexingOptions);
|
||||
|
||||
// Update selected folders and files when connector config changes
|
||||
useEffect(() => {
|
||||
const folders = (connector.config?.selected_folders as SelectedFolder[] | undefined) || [];
|
||||
const files = (connector.config?.selected_files as SelectedFolder[] | undefined) || [];
|
||||
const options =
|
||||
(connector.config?.indexing_options as IndexingOptions | undefined) ||
|
||||
DEFAULT_INDEXING_OPTIONS;
|
||||
setSelectedFolders(folders);
|
||||
setSelectedFiles(files);
|
||||
setIndexingOptions(options);
|
||||
}, [connector.config]);
|
||||
|
||||
const handleSelectFolders = (folders: SelectedFolder[]) => {
|
||||
setSelectedFolders(folders);
|
||||
const updateConfig = (
|
||||
folders: SelectedFolder[],
|
||||
files: SelectedFolder[],
|
||||
options: IndexingOptions
|
||||
) => {
|
||||
if (onConfigChange) {
|
||||
// Store folder IDs and names in config for indexing
|
||||
onConfigChange({
|
||||
...connector.config,
|
||||
selected_folders: folders,
|
||||
selected_files: selectedFiles, // Preserve existing files
|
||||
selected_files: files,
|
||||
indexing_options: options,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectFolders = (folders: SelectedFolder[]) => {
|
||||
setSelectedFolders(folders);
|
||||
updateConfig(folders, selectedFiles, indexingOptions);
|
||||
};
|
||||
|
||||
const handleSelectFiles = (files: SelectedFolder[]) => {
|
||||
setSelectedFiles(files);
|
||||
if (onConfigChange) {
|
||||
// Store file IDs and names in config for indexing
|
||||
onConfigChange({
|
||||
...connector.config,
|
||||
selected_folders: selectedFolders, // Preserve existing folders
|
||||
selected_files: files,
|
||||
});
|
||||
}
|
||||
updateConfig(selectedFolders, files, indexingOptions);
|
||||
};
|
||||
|
||||
const handleIndexingOptionChange = (key: keyof IndexingOptions, value: number | boolean) => {
|
||||
const newOptions = { ...indexingOptions, [key]: value };
|
||||
setIndexingOptions(newOptions);
|
||||
updateConfig(selectedFolders, selectedFiles, newOptions);
|
||||
};
|
||||
|
||||
const totalSelected = selectedFolders.length + selectedFiles.length;
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6 space-y-3 sm:space-y-4">
|
||||
<div className="space-y-1 sm:space-y-2">
|
||||
<h3 className="font-medium text-sm sm:text-base">Folder & File Selection</h3>
|
||||
<p className="text-xs sm:text-sm text-muted-foreground">
|
||||
Select specific folders and/or individual files to index. Only files directly in each
|
||||
folder will be processed—subfolders must be selected separately.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{totalSelected > 0 && (
|
||||
<div className="p-2 sm:p-3 bg-muted rounded-lg text-xs sm:text-sm space-y-1 sm:space-y-2">
|
||||
<p className="font-medium">
|
||||
Selected {totalSelected} item{totalSelected > 1 ? "s" : ""}:
|
||||
{selectedFolders.length > 0 &&
|
||||
` ${selectedFolders.length} folder${selectedFolders.length > 1 ? "s" : ""}`}
|
||||
{selectedFiles.length > 0 &&
|
||||
` ${selectedFiles.length} file${selectedFiles.length > 1 ? "s" : ""}`}
|
||||
<div className="space-y-4">
|
||||
{/* Folder & File Selection */}
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6 space-y-3 sm:space-y-4">
|
||||
<div className="space-y-1 sm:space-y-2">
|
||||
<h3 className="font-medium text-sm sm:text-base">Folder & File Selection</h3>
|
||||
<p className="text-xs sm:text-sm text-muted-foreground">
|
||||
Select specific folders and/or individual files to index.
|
||||
</p>
|
||||
<div className="max-h-20 sm:max-h-24 overflow-y-auto space-y-1">
|
||||
{selectedFolders.map((folder) => (
|
||||
<p
|
||||
key={folder.id}
|
||||
className="text-xs sm:text-sm text-muted-foreground truncate"
|
||||
title={folder.name}
|
||||
>
|
||||
📁 {folder.name}
|
||||
</p>
|
||||
))}
|
||||
{selectedFiles.map((file) => (
|
||||
<p
|
||||
key={file.id}
|
||||
className="text-xs sm:text-sm text-muted-foreground truncate"
|
||||
title={file.name}
|
||||
>
|
||||
📄 {file.name}
|
||||
</p>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showFolderSelector ? (
|
||||
<div className="space-y-2 sm:space-y-3">
|
||||
<GoogleDriveFolderTree
|
||||
connectorId={connector.id}
|
||||
selectedFolders={selectedFolders}
|
||||
onSelectFolders={handleSelectFolders}
|
||||
selectedFiles={selectedFiles}
|
||||
onSelectFiles={handleSelectFiles}
|
||||
/>
|
||||
{totalSelected > 0 && (
|
||||
<div className="p-2 sm:p-3 bg-muted rounded-lg text-xs sm:text-sm space-y-1 sm:space-y-2">
|
||||
<p className="font-medium">
|
||||
Selected {totalSelected} item{totalSelected > 1 ? "s" : ""}: {(() => {
|
||||
const parts: string[] = [];
|
||||
if (selectedFolders.length > 0) {
|
||||
parts.push(
|
||||
`${selectedFolders.length} folder${selectedFolders.length > 1 ? "s" : ""}`
|
||||
);
|
||||
}
|
||||
if (selectedFiles.length > 0) {
|
||||
parts.push(`${selectedFiles.length} file${selectedFiles.length > 1 ? "s" : ""}`);
|
||||
}
|
||||
return parts.length > 0 ? `(${parts.join(" ")})` : "";
|
||||
})()}
|
||||
</p>
|
||||
<div className="max-h-20 sm:max-h-24 overflow-y-auto space-y-1">
|
||||
{selectedFolders.map((folder) => (
|
||||
<p
|
||||
key={folder.id}
|
||||
className="text-xs sm:text-sm text-muted-foreground truncate flex items-center gap-1.5"
|
||||
title={folder.name}
|
||||
>
|
||||
<FolderClosed className="size-3.5 shrink-0 text-gray-500" />
|
||||
{folder.name}
|
||||
</p>
|
||||
))}
|
||||
{selectedFiles.map((file) => (
|
||||
<p
|
||||
key={file.id}
|
||||
className="text-xs sm:text-sm text-muted-foreground truncate flex items-center gap-1.5"
|
||||
title={file.name}
|
||||
>
|
||||
{getFileIconFromName(file.name)}
|
||||
{file.name}
|
||||
</p>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showFolderSelector ? (
|
||||
<div className="space-y-2 sm:space-y-3">
|
||||
<GoogleDriveFolderTree
|
||||
connectorId={connector.id}
|
||||
selectedFolders={selectedFolders}
|
||||
onSelectFolders={handleSelectFolders}
|
||||
selectedFiles={selectedFiles}
|
||||
onSelectFiles={handleSelectFiles}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setShowFolderSelector(false)}
|
||||
className="bg-slate-400/5 dark:bg-white/5 border-slate-400/20 hover:bg-slate-400/10 dark:hover:bg-white/10 text-xs sm:text-sm h-8 sm:h-9"
|
||||
>
|
||||
Done Selecting
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setShowFolderSelector(false)}
|
||||
onClick={() => setShowFolderSelector(true)}
|
||||
className="bg-slate-400/5 dark:bg-white/5 border-slate-400/20 hover:bg-slate-400/10 dark:hover:bg-white/10 text-xs sm:text-sm h-8 sm:h-9"
|
||||
>
|
||||
Done Selecting
|
||||
{totalSelected > 0 ? "Change Selection" : "Select Folders & Files"}
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => setShowFolderSelector(true)}
|
||||
className="bg-slate-400/5 dark:bg-white/5 border-slate-400/20 hover:bg-slate-400/10 dark:hover:bg-white/10 text-xs sm:text-sm h-8 sm:h-9"
|
||||
>
|
||||
{totalSelected > 0 ? "Change Selection" : "Select Folders & Files"}
|
||||
</Button>
|
||||
)}
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Alert className="bg-slate-400/5 dark:bg-white/5 border-slate-400/20 p-2 sm:p-3 flex items-center gap-2 [&>svg]:relative [&>svg]:left-0 [&>svg]:top-0 [&>svg+div]:translate-y-0">
|
||||
<Info className="h-3 w-3 sm:h-4 sm:w-4 shrink-0" />
|
||||
<AlertDescription className="text-[10px] sm:text-xs !pl-0">
|
||||
Folder and file selection is used when indexing. You can change this selection when you
|
||||
start indexing.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
{/* Indexing Options */}
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6 space-y-4">
|
||||
<div className="space-y-1 sm:space-y-2">
|
||||
<h3 className="font-medium text-sm sm:text-base">Indexing Options</h3>
|
||||
<p className="text-xs sm:text-sm text-muted-foreground">
|
||||
Configure how files are indexed from your Google Drive.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Max files per folder */}
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="max-files" className="text-sm font-medium">
|
||||
Max files per folder
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Maximum number of files to index from each folder
|
||||
</p>
|
||||
</div>
|
||||
<Select
|
||||
value={indexingOptions.max_files_per_folder.toString()}
|
||||
onValueChange={(value) =>
|
||||
handleIndexingOptionChange("max_files_per_folder", parseInt(value, 10))
|
||||
}
|
||||
>
|
||||
<SelectTrigger
|
||||
id="max-files"
|
||||
className="w-[140px] bg-slate-400/5 dark:bg-slate-400/5 border-slate-400/20 text-xs sm:text-sm"
|
||||
>
|
||||
<SelectValue placeholder="Select limit" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="z-[100]">
|
||||
<SelectItem value="50" className="text-xs sm:text-sm">
|
||||
50 files
|
||||
</SelectItem>
|
||||
<SelectItem value="100" className="text-xs sm:text-sm">
|
||||
100 files
|
||||
</SelectItem>
|
||||
<SelectItem value="250" className="text-xs sm:text-sm">
|
||||
250 files
|
||||
</SelectItem>
|
||||
<SelectItem value="500" className="text-xs sm:text-sm">
|
||||
500 files
|
||||
</SelectItem>
|
||||
<SelectItem value="1000" className="text-xs sm:text-sm">
|
||||
1000 files
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Incremental sync toggle */}
|
||||
<div className="flex items-center justify-between pt-2 border-t border-slate-400/20">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="incremental-sync" className="text-sm font-medium">
|
||||
Incremental sync
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Only sync changes since last index (faster). Disable for a full re-index.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="incremental-sync"
|
||||
checked={indexingOptions.incremental_sync}
|
||||
onCheckedChange={(checked) => handleIndexingOptionChange("incremental_sync", checked)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Include subfolders toggle */}
|
||||
<div className="flex items-center justify-between pt-2 border-t border-slate-400/20">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="include-subfolders" className="text-sm font-medium">
|
||||
Include subfolders
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Recursively index files in subfolders of selected folders
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="include-subfolders"
|
||||
checked={indexingOptions.include_subfolders}
|
||||
onCheckedChange={(checked) => handleIndexingOptionChange("include_subfolders", checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,293 @@
|
|||
"use client";
|
||||
|
||||
import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import type { MCPServerConfig } from "@/contracts/types/mcp.types";
|
||||
import {
|
||||
type MCPConnectionTestResult,
|
||||
parseMCPConfig,
|
||||
testMCPConnection,
|
||||
} from "../../utils/mcp-config-validator";
|
||||
import type { ConnectorConfigProps } from "../index";
|
||||
|
||||
interface MCPConfigProps extends ConnectorConfigProps {
|
||||
onNameChange?: (name: string) => void;
|
||||
}
|
||||
|
||||
export const MCPConfig: FC<MCPConfigProps> = ({ connector, onConfigChange, onNameChange }) => {
|
||||
const [name, setName] = useState<string>("");
|
||||
const [configJson, setConfigJson] = useState("");
|
||||
const [jsonError, setJsonError] = useState<string | null>(null);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [showDetails, setShowDetails] = useState(false);
|
||||
const [testResult, setTestResult] = useState<MCPConnectionTestResult | null>(null);
|
||||
const initializedRef = useRef(false);
|
||||
|
||||
// Check if this is a valid MCP connector
|
||||
const isValidConnector = connector.connector_type === EnumConnectorName.MCP_CONNECTOR;
|
||||
|
||||
// Initialize form from connector config (only on mount)
|
||||
// We intentionally only read connector.name and connector.config on initial mount
|
||||
// to preserve user edits during the session
|
||||
useEffect(() => {
|
||||
if (!isValidConnector || initializedRef.current) return;
|
||||
initializedRef.current = true;
|
||||
|
||||
if (connector.name) {
|
||||
setName(connector.name);
|
||||
}
|
||||
|
||||
const serverConfig = connector.config?.server_config as MCPServerConfig | undefined;
|
||||
if (serverConfig) {
|
||||
const transport = serverConfig.transport || "stdio";
|
||||
|
||||
// Build config object based on transport type
|
||||
let configObj: Record<string, unknown>;
|
||||
|
||||
if (transport === "streamable-http" || transport === "http" || transport === "sse") {
|
||||
// HTTP transport - use url and headers
|
||||
configObj = {
|
||||
url: (serverConfig as any).url || "",
|
||||
headers: (serverConfig as any).headers || {},
|
||||
transport: transport,
|
||||
};
|
||||
} else {
|
||||
// stdio transport (default) - use command, args, env
|
||||
configObj = {
|
||||
command: (serverConfig as any).command || "",
|
||||
args: (serverConfig as any).args || [],
|
||||
env: (serverConfig as any).env || {},
|
||||
transport: transport,
|
||||
};
|
||||
}
|
||||
|
||||
setConfigJson(JSON.stringify(configObj, null, 2));
|
||||
}
|
||||
}, [isValidConnector, connector.name, connector.config?.server_config]);
|
||||
|
||||
const handleNameChange = useCallback(
|
||||
(value: string) => {
|
||||
setName(value);
|
||||
if (onNameChange) {
|
||||
onNameChange(value);
|
||||
}
|
||||
},
|
||||
[onNameChange]
|
||||
);
|
||||
|
||||
const parseConfig = useCallback(() => {
|
||||
const result = parseMCPConfig(configJson);
|
||||
if (result.error) {
|
||||
setJsonError(result.error);
|
||||
} else {
|
||||
setJsonError(null);
|
||||
}
|
||||
return result.config;
|
||||
}, [configJson]);
|
||||
|
||||
const handleConfigChange = useCallback(
|
||||
(value: string) => {
|
||||
setConfigJson(value);
|
||||
setJsonError(null);
|
||||
|
||||
// Use shared utility for validation and parsing (with caching)
|
||||
const result = parseMCPConfig(value);
|
||||
|
||||
if (result.config && onConfigChange) {
|
||||
// Valid config - update parent immediately
|
||||
onConfigChange({ server_config: result.config });
|
||||
}
|
||||
// Ignore errors while typing - only show errors when user tests or saves
|
||||
},
|
||||
[onConfigChange]
|
||||
);
|
||||
|
||||
const handleTestConnection = useCallback(async () => {
|
||||
const serverConfig = parseConfig();
|
||||
if (!serverConfig) {
|
||||
setTestResult({
|
||||
status: "error",
|
||||
message: jsonError || "Invalid configuration",
|
||||
tools: [],
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Update parent with the config
|
||||
if (onConfigChange) {
|
||||
onConfigChange({ server_config: serverConfig });
|
||||
}
|
||||
|
||||
setIsTesting(true);
|
||||
setTestResult(null);
|
||||
|
||||
const result = await testMCPConnection(serverConfig);
|
||||
setTestResult(result);
|
||||
setIsTesting(false);
|
||||
}, [parseConfig, jsonError, onConfigChange]);
|
||||
|
||||
// Validate that this is an MCP connector - must be after all hooks
|
||||
if (!isValidConnector) {
|
||||
console.error("MCPConfig received non-MCP connector:", connector.connector_type);
|
||||
return (
|
||||
<Alert className="border-red-500/50 bg-red-500/10">
|
||||
<XCircle className="h-4 w-4 text-red-600" />
|
||||
<AlertTitle>Invalid Connector Type</AlertTitle>
|
||||
<AlertDescription>This component can only be used with MCP connectors.</AlertDescription>
|
||||
</Alert>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Server Name */}
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6 space-y-3 sm:space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="name" className="text-xs sm:text-sm">
|
||||
Server Name
|
||||
</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={name}
|
||||
onChange={(e) => handleNameChange(e.target.value)}
|
||||
placeholder="e.g., Filesystem Server"
|
||||
className="border-slate-400/20 focus-visible:border-slate-400/40"
|
||||
required
|
||||
/>
|
||||
<p className="text-[10px] sm:text-xs text-muted-foreground">
|
||||
A friendly name to identify this connector.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Server Configuration */}
|
||||
<div className="space-y-4">
|
||||
<h3 className="font-medium text-sm sm:text-base flex items-center gap-2">
|
||||
<Server className="h-4 w-4" />
|
||||
Server Configuration
|
||||
</h3>
|
||||
|
||||
<div className="rounded-xl border border-border bg-slate-400/5 dark:bg-white/5 p-3 sm:p-6 space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="config">MCP Server Configuration (JSON)</Label>
|
||||
<Textarea
|
||||
id="config"
|
||||
value={configJson}
|
||||
onChange={(e) => handleConfigChange(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Tab") {
|
||||
e.preventDefault();
|
||||
const target = e.target as HTMLTextAreaElement;
|
||||
const start = target.selectionStart;
|
||||
const end = target.selectionEnd;
|
||||
const indent = " "; // 2 spaces for JSON
|
||||
const newValue =
|
||||
configJson.substring(0, start) + indent + configJson.substring(end);
|
||||
handleConfigChange(newValue);
|
||||
// Set cursor position after the inserted tab
|
||||
requestAnimationFrame(() => {
|
||||
target.selectionStart = target.selectionEnd = start + indent.length;
|
||||
});
|
||||
}
|
||||
}}
|
||||
rows={16}
|
||||
className={`font-mono text-xs ${jsonError ? "border-red-500" : ""}`}
|
||||
/>
|
||||
{jsonError && <p className="text-xs text-red-500">JSON Error: {jsonError}</p>}
|
||||
<p className="text-[10px] sm:text-xs text-muted-foreground">
|
||||
<strong>Local (stdio):</strong> command, args, env, transport: "stdio"
|
||||
<br />
|
||||
<strong>Remote (HTTP):</strong> url, headers, transport: "streamable-http"
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Test Connection */}
|
||||
<div className="pt-4">
|
||||
<Button
|
||||
type="button"
|
||||
onClick={handleTestConnection}
|
||||
disabled={isTesting}
|
||||
variant="secondary"
|
||||
className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80"
|
||||
>
|
||||
{isTesting ? "Testing Connection" : "Test Connection"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Test Result */}
|
||||
{testResult && (
|
||||
<Alert
|
||||
className={
|
||||
testResult.status === "success"
|
||||
? "border-green-500/50 bg-green-500/10"
|
||||
: "border-red-500/50 bg-red-500/10"
|
||||
}
|
||||
>
|
||||
{testResult.status === "success" ? (
|
||||
<CheckCircle2 className="h-4 w-4 text-green-600" />
|
||||
) : (
|
||||
<XCircle className="h-4 w-4 text-red-600" />
|
||||
)}
|
||||
<div className="flex-1">
|
||||
<div className="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-2 sm:gap-0">
|
||||
<AlertTitle className="text-sm">
|
||||
{testResult.status === "success"
|
||||
? "Connection Successful"
|
||||
: "Connection Failed"}
|
||||
</AlertTitle>
|
||||
{testResult.tools.length > 0 && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 self-start sm:self-auto text-xs"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setShowDetails(!showDetails);
|
||||
}}
|
||||
>
|
||||
{showDetails ? (
|
||||
<>
|
||||
<ChevronUp className="h-3 w-3 mr-1" />
|
||||
<span className="hidden sm:inline">Hide Details</span>
|
||||
<span className="sm:hidden">Hide</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ChevronDown className="h-3 w-3 mr-1" />
|
||||
<span className="hidden sm:inline">Show Details</span>
|
||||
<span className="sm:hidden">Show</span>
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<AlertDescription className="text-xs mt-1">
|
||||
{testResult.message}
|
||||
{showDetails && testResult.tools.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-green-500/20">
|
||||
<p className="font-semibold mb-2">Available tools:</p>
|
||||
<ul className="list-disc list-inside text-xs space-y-0.5">
|
||||
{testResult.tools.map((tool) => (
|
||||
<li key={tool.name}>{tool.name}</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
</AlertDescription>
|
||||
</div>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -14,6 +14,7 @@ import { GoogleDriveConfig } from "./components/google-drive-config";
|
|||
import { JiraConfig } from "./components/jira-config";
|
||||
import { LinkupApiConfig } from "./components/linkup-api-config";
|
||||
import { LumaConfig } from "./components/luma-config";
|
||||
import { MCPConfig } from "./components/mcp-config";
|
||||
import { SearxngConfig } from "./components/searxng-config";
|
||||
import { SlackConfig } from "./components/slack-config";
|
||||
import { TavilyApiConfig } from "./components/tavily-api-config";
|
||||
|
|
@ -24,6 +25,7 @@ export interface ConnectorConfigProps {
|
|||
connector: SearchSourceConnector;
|
||||
onConfigChange?: (config: Record<string, unknown>) => void;
|
||||
onNameChange?: (name: string) => void;
|
||||
searchSpaceId?: string;
|
||||
}
|
||||
|
||||
export type ConnectorConfigComponent = FC<ConnectorConfigProps>;
|
||||
|
|
@ -69,6 +71,8 @@ export function getConnectorConfigComponent(
|
|||
return LumaConfig;
|
||||
case "CIRCLEBACK_CONNECTOR":
|
||||
return CirclebackConfig;
|
||||
case "MCP_CONNECTOR":
|
||||
return MCPConfig;
|
||||
// OAuth connectors (Gmail, Calendar, Airtable, Notion) and others don't need special config UI
|
||||
default:
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ export const ConnectorConnectView: FC<ConnectorConnectViewProps> = ({
|
|||
GITHUB_CONNECTOR: "github-connect-form",
|
||||
LUMA_CONNECTOR: "luma-connect-form",
|
||||
CIRCLEBACK_CONNECTOR: "circleback-connect-form",
|
||||
MCP_CONNECTOR: "mcp-connect-form",
|
||||
};
|
||||
const formId = formIdMap[connectorType];
|
||||
if (formId) {
|
||||
|
|
@ -98,7 +99,10 @@ export const ConnectorConnectView: FC<ConnectorConnectViewProps> = ({
|
|||
</div>
|
||||
<div>
|
||||
<h2 className="text-xl sm:text-2xl font-semibold tracking-tight">
|
||||
Connect {getConnectorTypeDisplay(connectorType)}
|
||||
Connect{" "}
|
||||
{connectorType === "MCP_CONNECTOR"
|
||||
? "MCP Server"
|
||||
: getConnectorTypeDisplay(connectorType)}
|
||||
</h2>
|
||||
<p className="text-xs sm:text-base text-muted-foreground mt-1">
|
||||
Enter your connection details
|
||||
|
|
@ -135,10 +139,14 @@ export const ConnectorConnectView: FC<ConnectorConnectViewProps> = ({
|
|||
{isSubmitting ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Connecting...
|
||||
Connecting
|
||||
</>
|
||||
) : (
|
||||
<>Connect {getConnectorTypeDisplay(connectorType)}</>
|
||||
<>
|
||||
{connectorType === "MCP_CONNECTOR"
|
||||
? "Connect"
|
||||
: `Connect ${getConnectorTypeDisplay(connectorType)}`}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ interface ConnectorEditViewProps {
|
|||
isSaving: boolean;
|
||||
isDisconnecting: boolean;
|
||||
isIndexing?: boolean;
|
||||
searchSpaceId?: string;
|
||||
onStartDateChange: (date: Date | undefined) => void;
|
||||
onEndDateChange: (date: Date | undefined) => void;
|
||||
onPeriodicEnabledChange: (enabled: boolean) => void;
|
||||
|
|
@ -40,6 +41,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
isSaving,
|
||||
isDisconnecting,
|
||||
isIndexing = false,
|
||||
searchSpaceId,
|
||||
onStartDateChange,
|
||||
onEndDateChange,
|
||||
onPeriodicEnabledChange,
|
||||
|
|
@ -170,7 +172,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
{isQuickIndexing || isIndexing ? (
|
||||
<>
|
||||
<RefreshCw className="mr-2 h-4 w-4 animate-spin" />
|
||||
Indexing...
|
||||
Syncing
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
|
|
@ -197,6 +199,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
connector={connector}
|
||||
onConfigChange={onConfigChange}
|
||||
onNameChange={onNameChange}
|
||||
searchSpaceId={searchSpaceId}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
|
@ -218,15 +221,36 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
/>
|
||||
)}
|
||||
|
||||
{/* Periodic sync - not shown for Google Drive */}
|
||||
{connector.connector_type !== "GOOGLE_DRIVE_CONNECTOR" && (
|
||||
<PeriodicSyncConfig
|
||||
enabled={periodicEnabled}
|
||||
frequencyMinutes={frequencyMinutes}
|
||||
onEnabledChange={onPeriodicEnabledChange}
|
||||
onFrequencyChange={onFrequencyChange}
|
||||
/>
|
||||
)}
|
||||
{/* Periodic sync - shown for all indexable connectors */}
|
||||
{(() => {
|
||||
// Check if Google Drive has folders/files selected
|
||||
const isGoogleDrive = connector.connector_type === "GOOGLE_DRIVE_CONNECTOR";
|
||||
const selectedFolders =
|
||||
(connector.config?.selected_folders as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined) || [];
|
||||
const selectedFiles =
|
||||
(connector.config?.selected_files as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined) || [];
|
||||
const hasItemsSelected = selectedFolders.length > 0 || selectedFiles.length > 0;
|
||||
const isDisabled = isGoogleDrive && !hasItemsSelected;
|
||||
|
||||
return (
|
||||
<PeriodicSyncConfig
|
||||
enabled={periodicEnabled}
|
||||
frequencyMinutes={frequencyMinutes}
|
||||
onEnabledChange={onPeriodicEnabledChange}
|
||||
onFrequencyChange={onFrequencyChange}
|
||||
disabled={isDisabled}
|
||||
disabledMessage={
|
||||
isDisabled
|
||||
? "Select at least one folder or file above to enable periodic sync"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
})()}
|
||||
</>
|
||||
)}
|
||||
|
||||
|
|
@ -277,7 +301,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
{isDisconnecting ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Disconnecting...
|
||||
Disconnecting
|
||||
</>
|
||||
) : (
|
||||
"Confirm Disconnect"
|
||||
|
|
|
|||
|
|
@ -160,6 +160,12 @@ export const OTHER_CONNECTORS = [
|
|||
description: "Receive meeting notes, transcripts",
|
||||
connectorType: EnumConnectorName.CIRCLEBACK_CONNECTOR,
|
||||
},
|
||||
{
|
||||
id: "mcp-connector",
|
||||
title: "MCPs",
|
||||
description: "Connect to MCP servers for AI tools",
|
||||
connectorType: EnumConnectorName.MCP_CONNECTOR,
|
||||
},
|
||||
] as const;
|
||||
|
||||
// Re-export IndexingConfigState from schemas for backward compatibility
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import { searchSourceConnectorTypeEnum } from "@/contracts/types/connector.types
|
|||
export const connectorPopupQueryParamsSchema = z.object({
|
||||
modal: z.enum(["connectors"]).optional(),
|
||||
tab: z.enum(["all", "active"]).optional(),
|
||||
view: z.enum(["configure", "edit", "connect", "youtube", "accounts"]).optional(),
|
||||
view: z.enum(["configure", "edit", "connect", "youtube", "accounts", "mcp-list"]).optional(),
|
||||
connector: z.string().optional(),
|
||||
connectorId: z.string().optional(),
|
||||
connectorType: z.string().optional(),
|
||||
|
|
|
|||
|
|
@ -80,12 +80,18 @@ export const useConnectorDialog = () => {
|
|||
connectorTitle: string;
|
||||
} | null>(null);
|
||||
|
||||
// MCP list view state (for managing multiple MCP connectors)
|
||||
const [viewingMCPList, setViewingMCPList] = useState(false);
|
||||
|
||||
// Track if we came from accounts list when entering edit mode
|
||||
const [cameFromAccountsList, setCameFromAccountsList] = useState<{
|
||||
connectorType: string;
|
||||
connectorTitle: string;
|
||||
} | null>(null);
|
||||
|
||||
// Track if we came from MCP list view when entering edit mode
|
||||
const [cameFromMCPList, setCameFromMCPList] = useState(false);
|
||||
|
||||
// Helper function to get frequency label
|
||||
const getFrequencyLabel = useCallback((minutes: string): string => {
|
||||
switch (minutes) {
|
||||
|
|
@ -139,6 +145,16 @@ export const useConnectorDialog = () => {
|
|||
setViewingAccountsType(null);
|
||||
}
|
||||
|
||||
// Clear MCP list view if view is not "mcp-list" anymore
|
||||
if (params.view !== "mcp-list" && viewingMCPList) {
|
||||
setViewingMCPList(false);
|
||||
}
|
||||
|
||||
// Handle MCP list view
|
||||
if (params.view === "mcp-list" && !viewingMCPList) {
|
||||
setViewingMCPList(true);
|
||||
}
|
||||
|
||||
// Handle connect view
|
||||
if (params.view === "connect" && params.connectorType && !connectingConnectorType) {
|
||||
setConnectingConnectorType(params.connectorType);
|
||||
|
|
@ -203,11 +219,9 @@ export const useConnectorDialog = () => {
|
|||
setEditingConnector(connector);
|
||||
setConnectorConfig(connector.config);
|
||||
setConnectorName(connector.name);
|
||||
// Load existing periodic sync settings (disabled for Google Drive and non-indexable connectors)
|
||||
// Load existing periodic sync settings (disabled for non-indexable connectors)
|
||||
setPeriodicEnabled(
|
||||
connector.connector_type === "GOOGLE_DRIVE_CONNECTOR" || !connector.is_indexable
|
||||
? false
|
||||
: connector.periodic_indexing_enabled
|
||||
!connector.is_indexable ? false : connector.periodic_indexing_enabled
|
||||
);
|
||||
setFrequencyMinutes(connector.indexing_frequency_minutes?.toString() || "1440");
|
||||
// Reset dates - user can set new ones for re-indexing
|
||||
|
|
@ -421,6 +435,7 @@ export const useConnectorDialog = () => {
|
|||
connector_type: EnumConnectorName.WEBCRAWLER_CONNECTOR,
|
||||
config: {},
|
||||
is_indexable: true,
|
||||
is_active: true,
|
||||
last_indexed_at: null,
|
||||
periodic_indexing_enabled: false,
|
||||
indexing_frequency_minutes: null,
|
||||
|
|
@ -491,20 +506,23 @@ export const useConnectorDialog = () => {
|
|||
|
||||
// Handle submitting connect form
|
||||
const handleSubmitConnectForm = useCallback(
|
||||
async (formData: {
|
||||
name: string;
|
||||
connector_type: string;
|
||||
config: Record<string, unknown>;
|
||||
is_indexable: boolean;
|
||||
last_indexed_at: null;
|
||||
periodic_indexing_enabled: boolean;
|
||||
indexing_frequency_minutes: number | null;
|
||||
next_scheduled_at: null;
|
||||
startDate?: Date;
|
||||
endDate?: Date;
|
||||
periodicEnabled?: boolean;
|
||||
frequencyMinutes?: string;
|
||||
}) => {
|
||||
async (
|
||||
formData: {
|
||||
name: string;
|
||||
connector_type: string;
|
||||
config: Record<string, unknown>;
|
||||
is_indexable: boolean;
|
||||
last_indexed_at: null;
|
||||
periodic_indexing_enabled: boolean;
|
||||
indexing_frequency_minutes: number | null;
|
||||
next_scheduled_at: null;
|
||||
startDate?: Date;
|
||||
endDate?: Date;
|
||||
periodicEnabled?: boolean;
|
||||
frequencyMinutes?: string;
|
||||
},
|
||||
onIndexingStart?: (connectorId: number) => void
|
||||
) => {
|
||||
if (!searchSpaceId || !connectingConnectorType) return;
|
||||
|
||||
// Prevent multiple submissions using ref for immediate check
|
||||
|
|
@ -522,6 +540,7 @@ export const useConnectorDialog = () => {
|
|||
data: {
|
||||
...connectorData,
|
||||
connector_type: connectorData.connector_type as EnumConnectorName,
|
||||
is_active: true,
|
||||
next_scheduled_at: connectorData.next_scheduled_at as string | null,
|
||||
},
|
||||
queryParams: {
|
||||
|
|
@ -603,6 +622,11 @@ export const useConnectorDialog = () => {
|
|||
});
|
||||
}
|
||||
|
||||
// Notify caller that indexing is starting (for UI syncing state)
|
||||
if (onIndexingStart) {
|
||||
onIndexingStart(connector.id);
|
||||
}
|
||||
|
||||
// Start indexing (backend will use defaults if dates are undefined)
|
||||
const startDateStr = startDateForIndexing
|
||||
? format(startDateForIndexing, "yyyy-MM-dd")
|
||||
|
|
@ -620,13 +644,16 @@ export const useConnectorDialog = () => {
|
|||
},
|
||||
});
|
||||
|
||||
toast.success(`${connectorTitle} connected and indexing started!`, {
|
||||
const successMessage =
|
||||
currentConnectorType === "MCP_CONNECTOR"
|
||||
? `${connector.name} added successfully`
|
||||
: `${connectorTitle} connected and indexing started!`;
|
||||
toast.success(successMessage, {
|
||||
description: periodicEnabledForIndexing
|
||||
? `Periodic sync enabled every ${getFrequencyLabel(frequencyMinutesForIndexing)}.`
|
||||
: "You can continue working while we sync your data.",
|
||||
});
|
||||
|
||||
// Close modal and return to main view
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.delete("modal");
|
||||
url.searchParams.delete("tab");
|
||||
|
|
@ -682,7 +709,14 @@ export const useConnectorDialog = () => {
|
|||
await refetchAllConnectors();
|
||||
} else {
|
||||
// Other non-indexable connectors - just show success message and close
|
||||
toast.success(`${connectorTitle} connected successfully!`);
|
||||
const successMessage =
|
||||
currentConnectorType === "MCP_CONNECTOR"
|
||||
? `${connector.name} added successfully`
|
||||
: `${connectorTitle} connected successfully!`;
|
||||
toast.success(successMessage);
|
||||
|
||||
// Refresh connectors list before closing modal
|
||||
await refetchAllConnectors();
|
||||
|
||||
// Close modal and return to main view
|
||||
const url = new URL(window.location.href);
|
||||
|
|
@ -726,11 +760,18 @@ export const useConnectorDialog = () => {
|
|||
const handleBackFromConnect = useCallback(() => {
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("tab", "all");
|
||||
url.searchParams.delete("view");
|
||||
|
||||
// If we're connecting an MCP and came from list view, go back to list
|
||||
if (connectingConnectorType === "MCP_CONNECTOR" && viewingMCPList) {
|
||||
url.searchParams.set("view", "mcp-list");
|
||||
} else {
|
||||
url.searchParams.set("tab", "all");
|
||||
url.searchParams.delete("view");
|
||||
}
|
||||
|
||||
url.searchParams.delete("connectorType");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
}, [router]);
|
||||
}, [router, connectingConnectorType, viewingMCPList]);
|
||||
|
||||
// Handle going back from YouTube view
|
||||
const handleBackFromYouTube = useCallback(() => {
|
||||
|
|
@ -773,6 +814,38 @@ export const useConnectorDialog = () => {
|
|||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
}, [router]);
|
||||
|
||||
// Handle viewing MCP list
|
||||
const handleViewMCPList = useCallback(() => {
|
||||
if (!searchSpaceId) return;
|
||||
|
||||
setViewingMCPList(true);
|
||||
|
||||
// Update URL to show MCP list view
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("view", "mcp-list");
|
||||
window.history.pushState({ modal: true }, "", url.toString());
|
||||
}, [searchSpaceId]);
|
||||
|
||||
// Handle going back from MCP list view
|
||||
const handleBackFromMCPList = useCallback(() => {
|
||||
setViewingMCPList(false);
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.delete("view");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
}, [router]);
|
||||
|
||||
// Handle adding new MCP from list view
|
||||
const handleAddNewMCPFromList = useCallback(() => {
|
||||
setConnectingConnectorType("MCP_CONNECTOR");
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("view", "connect");
|
||||
url.searchParams.set("connectorType", "MCP_CONNECTOR");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
}, [router]);
|
||||
|
||||
// Handle starting indexing
|
||||
const handleStartIndexing = useCallback(
|
||||
async (refreshConnectors: () => void) => {
|
||||
|
|
@ -809,20 +882,14 @@ export const useConnectorDialog = () => {
|
|||
const endDateStr = endDate ? format(endDate, "yyyy-MM-dd") : undefined;
|
||||
|
||||
// Update connector with periodic sync settings and config changes
|
||||
// Note: Periodic sync is disabled for Google Drive connectors
|
||||
if (periodicEnabled || indexingConnectorConfig) {
|
||||
const frequency = periodicEnabled ? parseInt(frequencyMinutes, 10) : undefined;
|
||||
await updateConnector({
|
||||
id: indexingConfig.connectorId,
|
||||
data: {
|
||||
...(periodicEnabled &&
|
||||
indexingConfig.connectorType !== "GOOGLE_DRIVE_CONNECTOR" && {
|
||||
periodic_indexing_enabled: true,
|
||||
indexing_frequency_minutes: frequency,
|
||||
}),
|
||||
...(indexingConfig.connectorType === "GOOGLE_DRIVE_CONNECTOR" && {
|
||||
periodic_indexing_enabled: false,
|
||||
indexing_frequency_minutes: null,
|
||||
...(periodicEnabled && {
|
||||
periodic_indexing_enabled: true,
|
||||
indexing_frequency_minutes: frequency,
|
||||
}),
|
||||
...(indexingConnectorConfig && {
|
||||
config: indexingConnectorConfig,
|
||||
|
|
@ -839,11 +906,18 @@ export const useConnectorDialog = () => {
|
|||
const selectedFiles = indexingConnectorConfig.selected_files as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined;
|
||||
const indexingOptions = indexingConnectorConfig.indexing_options as
|
||||
| {
|
||||
max_files_per_folder: number;
|
||||
incremental_sync: boolean;
|
||||
include_subfolders: boolean;
|
||||
}
|
||||
| undefined;
|
||||
if (
|
||||
(selectedFolders && selectedFolders.length > 0) ||
|
||||
(selectedFiles && selectedFiles.length > 0)
|
||||
) {
|
||||
// Index with folder/file selection
|
||||
// Index with folder/file selection and indexing options
|
||||
await indexConnector({
|
||||
connector_id: indexingConfig.connectorId,
|
||||
queryParams: {
|
||||
|
|
@ -852,6 +926,11 @@ export const useConnectorDialog = () => {
|
|||
body: {
|
||||
folders: selectedFolders || [],
|
||||
files: selectedFiles || [],
|
||||
indexing_options: indexingOptions || {
|
||||
max_files_per_folder: 100,
|
||||
incremental_sync: true,
|
||||
include_subfolders: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
} else {
|
||||
|
|
@ -891,7 +970,7 @@ export const useConnectorDialog = () => {
|
|||
);
|
||||
|
||||
// Track periodic indexing started if enabled
|
||||
if (periodicEnabled && indexingConfig.connectorType !== "GOOGLE_DRIVE_CONNECTOR") {
|
||||
if (periodicEnabled) {
|
||||
trackPeriodicIndexingStarted(
|
||||
Number(searchSpaceId),
|
||||
indexingConfig.connectorType,
|
||||
|
|
@ -958,6 +1037,13 @@ export const useConnectorDialog = () => {
|
|||
(connector: SearchSourceConnector) => {
|
||||
if (!searchSpaceId) return;
|
||||
|
||||
// For MCP connectors from "All Connectors" tab, show the list view instead of directly editing
|
||||
// (unless we're already in the MCP list view or on the Active tab where individual MCPs are shown)
|
||||
if (connector.connector_type === "MCP_CONNECTOR" && !viewingMCPList && activeTab === "all") {
|
||||
handleViewMCPList();
|
||||
return;
|
||||
}
|
||||
|
||||
// All connector types should be handled in the popup edit view
|
||||
// Validate connector data
|
||||
const connectorValidation = searchSourceConnector.safeParse(connector);
|
||||
|
|
@ -974,6 +1060,13 @@ export const useConnectorDialog = () => {
|
|||
setCameFromAccountsList(null);
|
||||
}
|
||||
|
||||
// Track if we came from MCP list view
|
||||
if (viewingMCPList && connector.connector_type === "MCP_CONNECTOR") {
|
||||
setCameFromMCPList(true);
|
||||
} else {
|
||||
setCameFromMCPList(false);
|
||||
}
|
||||
|
||||
// Track index with date range opened event
|
||||
if (connector.is_indexable) {
|
||||
trackIndexWithDateRangeOpened(
|
||||
|
|
@ -985,12 +1078,8 @@ export const useConnectorDialog = () => {
|
|||
|
||||
setEditingConnector(connector);
|
||||
setConnectorName(connector.name);
|
||||
// Load existing periodic sync settings (disabled for Google Drive and non-indexable connectors)
|
||||
setPeriodicEnabled(
|
||||
connector.connector_type === "GOOGLE_DRIVE_CONNECTOR" || !connector.is_indexable
|
||||
? false
|
||||
: connector.periodic_indexing_enabled
|
||||
);
|
||||
// Load existing periodic sync settings (disabled for non-indexable connectors)
|
||||
setPeriodicEnabled(!connector.is_indexable ? false : connector.periodic_indexing_enabled);
|
||||
setFrequencyMinutes(connector.indexing_frequency_minutes?.toString() || "1440");
|
||||
// Reset dates - user can set new ones for re-indexing
|
||||
setStartDate(undefined);
|
||||
|
|
@ -1003,13 +1092,13 @@ export const useConnectorDialog = () => {
|
|||
url.searchParams.set("connectorId", connector.id.toString());
|
||||
window.history.pushState({ modal: true }, "", url.toString());
|
||||
},
|
||||
[searchSpaceId, viewingAccountsType]
|
||||
[searchSpaceId, viewingAccountsType, viewingMCPList, handleViewMCPList, activeTab]
|
||||
);
|
||||
|
||||
// Handle saving connector changes
|
||||
const handleSaveConnector = useCallback(
|
||||
async (refreshConnectors: () => void) => {
|
||||
if (!editingConnector || !searchSpaceId) return;
|
||||
if (!editingConnector || !searchSpaceId || isSaving) return;
|
||||
|
||||
// Validate date range (skip for Google Drive which uses folder selection, Webcrawler which uses config, and non-indexable connectors)
|
||||
if (
|
||||
|
|
@ -1030,6 +1119,24 @@ export const useConnectorDialog = () => {
|
|||
return;
|
||||
}
|
||||
|
||||
// Prevent periodic indexing for Google Drive without folders/files selected
|
||||
if (periodicEnabled && editingConnector.connector_type === "GOOGLE_DRIVE_CONNECTOR") {
|
||||
const selectedFolders = (connectorConfig || editingConnector.config)?.selected_folders as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined;
|
||||
const selectedFiles = (connectorConfig || editingConnector.config)?.selected_files as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined;
|
||||
const hasItemsSelected =
|
||||
(selectedFolders && selectedFolders.length > 0) ||
|
||||
(selectedFiles && selectedFiles.length > 0);
|
||||
|
||||
if (!hasItemsSelected) {
|
||||
toast.error("Select at least one folder or file to enable periodic sync");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate frequency minutes if periodic is enabled (only for indexable connectors)
|
||||
if (periodicEnabled && editingConnector.is_indexable) {
|
||||
const frequencyValidation = frequencyMinutesSchema.safeParse(frequencyMinutes);
|
||||
|
|
@ -1045,23 +1152,14 @@ export const useConnectorDialog = () => {
|
|||
const endDateStr = endDate ? format(endDate, "yyyy-MM-dd") : undefined;
|
||||
|
||||
// Update connector with periodic sync settings, config changes, and name
|
||||
// Note: Periodic sync is disabled for Google Drive connectors and non-indexable connectors
|
||||
const frequency =
|
||||
periodicEnabled && editingConnector.is_indexable ? parseInt(frequencyMinutes, 10) : null;
|
||||
await updateConnector({
|
||||
id: editingConnector.id,
|
||||
data: {
|
||||
name: connectorName || editingConnector.name,
|
||||
periodic_indexing_enabled:
|
||||
editingConnector.connector_type === "GOOGLE_DRIVE_CONNECTOR" ||
|
||||
!editingConnector.is_indexable
|
||||
? false
|
||||
: periodicEnabled,
|
||||
indexing_frequency_minutes:
|
||||
editingConnector.connector_type === "GOOGLE_DRIVE_CONNECTOR" ||
|
||||
!editingConnector.is_indexable
|
||||
? null
|
||||
: frequency,
|
||||
periodic_indexing_enabled: !editingConnector.is_indexable ? false : periodicEnabled,
|
||||
indexing_frequency_minutes: !editingConnector.is_indexable ? null : frequency,
|
||||
config: connectorConfig || editingConnector.config,
|
||||
},
|
||||
});
|
||||
|
|
@ -1079,6 +1177,13 @@ export const useConnectorDialog = () => {
|
|||
const selectedFiles = (connectorConfig || editingConnector.config)?.selected_files as
|
||||
| Array<{ id: string; name: string }>
|
||||
| undefined;
|
||||
const indexingOptions = (connectorConfig || editingConnector.config)?.indexing_options as
|
||||
| {
|
||||
max_files_per_folder: number;
|
||||
incremental_sync: boolean;
|
||||
include_subfolders: boolean;
|
||||
}
|
||||
| undefined;
|
||||
if (
|
||||
(selectedFolders && selectedFolders.length > 0) ||
|
||||
(selectedFiles && selectedFiles.length > 0)
|
||||
|
|
@ -1091,6 +1196,11 @@ export const useConnectorDialog = () => {
|
|||
body: {
|
||||
folders: selectedFolders || [],
|
||||
files: selectedFiles || [],
|
||||
indexing_options: indexingOptions || {
|
||||
max_files_per_folder: 100,
|
||||
incremental_sync: true,
|
||||
include_subfolders: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
const totalItems = (selectedFolders?.length || 0) + (selectedFiles?.length || 0);
|
||||
|
|
@ -1134,12 +1244,8 @@ export const useConnectorDialog = () => {
|
|||
);
|
||||
}
|
||||
|
||||
// Track periodic indexing if enabled (for non-Google Drive connectors)
|
||||
if (
|
||||
periodicEnabled &&
|
||||
editingConnector.is_indexable &&
|
||||
editingConnector.connector_type !== "GOOGLE_DRIVE_CONNECTOR"
|
||||
) {
|
||||
// Track periodic indexing if enabled
|
||||
if (periodicEnabled && editingConnector.is_indexable) {
|
||||
trackPeriodicIndexingStarted(
|
||||
Number(searchSpaceId),
|
||||
editingConnector.connector_type,
|
||||
|
|
@ -1148,7 +1254,10 @@ export const useConnectorDialog = () => {
|
|||
);
|
||||
}
|
||||
|
||||
toast.success(`${editingConnector.name} updated successfully`, {
|
||||
// Generate toast message based on connector type
|
||||
const toastTitle = `${editingConnector.name} updated successfully`;
|
||||
|
||||
toast.success(toastTitle, {
|
||||
description: periodicEnabled
|
||||
? `Periodic sync ${frequency ? `enabled every ${getFrequencyLabel(frequencyMinutes)}` : "enabled"}. ${indexingDescription}`
|
||||
: indexingDescription,
|
||||
|
|
@ -1176,6 +1285,7 @@ export const useConnectorDialog = () => {
|
|||
[
|
||||
editingConnector,
|
||||
searchSpaceId,
|
||||
isSaving,
|
||||
startDate,
|
||||
endDate,
|
||||
indexConnector,
|
||||
|
|
@ -1207,14 +1317,27 @@ export const useConnectorDialog = () => {
|
|||
editingConnector.id
|
||||
);
|
||||
|
||||
toast.success(`${editingConnector.name} disconnected successfully`);
|
||||
toast.success(
|
||||
editingConnector.connector_type === "MCP_CONNECTOR"
|
||||
? `${editingConnector.name} MCP server removed successfully`
|
||||
: `${editingConnector.name} disconnected successfully`
|
||||
);
|
||||
|
||||
// Update URL - the effect will handle closing the modal and clearing state
|
||||
// Update URL - for MCP from list view, go back to list; otherwise close modal
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.delete("modal");
|
||||
url.searchParams.delete("tab");
|
||||
url.searchParams.delete("view");
|
||||
url.searchParams.delete("connectorId");
|
||||
if (editingConnector.connector_type === "MCP_CONNECTOR" && cameFromMCPList) {
|
||||
// Go back to MCP list view only if we came from there
|
||||
setViewingMCPList(true);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("view", "mcp-list");
|
||||
url.searchParams.delete("connectorId");
|
||||
} else {
|
||||
// Close modal for all other cases
|
||||
url.searchParams.delete("modal");
|
||||
url.searchParams.delete("tab");
|
||||
url.searchParams.delete("view");
|
||||
url.searchParams.delete("connectorId");
|
||||
}
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
|
||||
refreshConnectors();
|
||||
|
|
@ -1266,6 +1389,21 @@ export const useConnectorDialog = () => {
|
|||
|
||||
// Handle going back from edit view
|
||||
const handleBackFromEdit = useCallback(() => {
|
||||
// If editing an MCP connector and came from MCP list, go back to MCP list view
|
||||
if (editingConnector?.connector_type === "MCP_CONNECTOR" && cameFromMCPList) {
|
||||
setViewingMCPList(true);
|
||||
setCameFromMCPList(false);
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("view", "mcp-list");
|
||||
url.searchParams.delete("connectorId");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
setEditingConnector(null);
|
||||
setConnectorName(null);
|
||||
setConnectorConfig(null);
|
||||
return;
|
||||
}
|
||||
|
||||
// If we came from accounts list view, go back there
|
||||
if (cameFromAccountsList && editingConnector) {
|
||||
// Restore accounts list view
|
||||
|
|
@ -1278,10 +1416,10 @@ export const useConnectorDialog = () => {
|
|||
url.searchParams.delete("connectorId");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
} else {
|
||||
// Otherwise, go back to main connector popup
|
||||
// Otherwise, go back to main connector popup (preserve the tab the user was on)
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("modal", "connectors");
|
||||
url.searchParams.set("tab", "all");
|
||||
url.searchParams.set("tab", activeTab); // Use current tab instead of always "all"
|
||||
url.searchParams.delete("view");
|
||||
url.searchParams.delete("connectorId");
|
||||
router.replace(url.pathname + url.search, { scroll: false });
|
||||
|
|
@ -1289,7 +1427,7 @@ export const useConnectorDialog = () => {
|
|||
setEditingConnector(null);
|
||||
setConnectorName(null);
|
||||
setConnectorConfig(null);
|
||||
}, [router, cameFromAccountsList, editingConnector]);
|
||||
}, [router, cameFromAccountsList, editingConnector, cameFromMCPList, activeTab]);
|
||||
|
||||
// Handle dialog open/close
|
||||
const handleOpenChange = useCallback(
|
||||
|
|
@ -1367,6 +1505,7 @@ export const useConnectorDialog = () => {
|
|||
searchSpaceId,
|
||||
allConnectors,
|
||||
viewingAccountsType,
|
||||
viewingMCPList,
|
||||
|
||||
// Setters
|
||||
setSearchQuery,
|
||||
|
|
@ -1395,6 +1534,9 @@ export const useConnectorDialog = () => {
|
|||
handleBackFromYouTube,
|
||||
handleViewAccountsList,
|
||||
handleBackFromAccountsList,
|
||||
handleViewMCPList,
|
||||
handleBackFromMCPList,
|
||||
handleAddNewMCPFromList,
|
||||
handleQuickIndexConnector,
|
||||
connectorConfig,
|
||||
setConnectorConfig,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import {
|
||||
type ConnectorStatusConfig,
|
||||
connectorStatusConfig,
|
||||
|
|
@ -14,34 +14,43 @@ export function useConnectorStatus() {
|
|||
/**
|
||||
* Get status configuration for a specific connector type
|
||||
*/
|
||||
const getConnectorStatus = (connectorType: string | undefined): ConnectorStatusConfig => {
|
||||
if (!connectorType) {
|
||||
return getDefaultConnectorStatus();
|
||||
}
|
||||
const getConnectorStatus = useCallback(
|
||||
(connectorType: string | undefined): ConnectorStatusConfig => {
|
||||
if (!connectorType) {
|
||||
return getDefaultConnectorStatus();
|
||||
}
|
||||
|
||||
return connectorStatusConfig.connectorStatuses[connectorType] || getDefaultConnectorStatus();
|
||||
};
|
||||
return connectorStatusConfig.connectorStatuses[connectorType] || getDefaultConnectorStatus();
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
/**
|
||||
* Check if a connector is enabled
|
||||
*/
|
||||
const isConnectorEnabled = (connectorType: string | undefined): boolean => {
|
||||
return getConnectorStatus(connectorType).enabled;
|
||||
};
|
||||
const isConnectorEnabled = useCallback(
|
||||
(connectorType: string | undefined): boolean => {
|
||||
return getConnectorStatus(connectorType).enabled;
|
||||
},
|
||||
[getConnectorStatus]
|
||||
);
|
||||
|
||||
/**
|
||||
* Get status message for a connector
|
||||
*/
|
||||
const getConnectorStatusMessage = (connectorType: string | undefined): string | null => {
|
||||
return getConnectorStatus(connectorType).statusMessage || null;
|
||||
};
|
||||
const getConnectorStatusMessage = useCallback(
|
||||
(connectorType: string | undefined): string | null => {
|
||||
return getConnectorStatus(connectorType).statusMessage || null;
|
||||
},
|
||||
[getConnectorStatus]
|
||||
);
|
||||
|
||||
/**
|
||||
* Check if warnings should be shown globally
|
||||
*/
|
||||
const shouldShowWarnings = (): boolean => {
|
||||
const shouldShowWarnings = useCallback((): boolean => {
|
||||
return connectorStatusConfig.globalSettings.showWarnings;
|
||||
};
|
||||
}, []);
|
||||
|
||||
return useMemo(
|
||||
() => ({
|
||||
|
|
@ -50,6 +59,6 @@ export function useConnectorStatus() {
|
|||
getConnectorStatusMessage,
|
||||
shouldShowWarnings,
|
||||
}),
|
||||
[]
|
||||
[getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings]
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
|
||||
/**
|
||||
* Hook to track which connectors are currently indexing using local state.
|
||||
*
|
||||
* This provides a better UX than polling by:
|
||||
* 1. Setting indexing state immediately when user triggers indexing (optimistic)
|
||||
* 2. Clearing indexing state when Electric SQL detects last_indexed_at changed
|
||||
*
|
||||
* The actual `last_indexed_at` value comes from Electric SQL/PGlite, not local state.
|
||||
*/
|
||||
export function useIndexingConnectors(connectors: SearchSourceConnector[]) {
|
||||
// Set of connector IDs that are currently indexing
|
||||
const [indexingConnectorIds, setIndexingConnectorIds] = useState<Set<number>>(new Set());
|
||||
|
||||
// Track previous last_indexed_at values to detect changes
|
||||
const previousLastIndexedAtRef = useRef<Map<number, string | null>>(new Map());
|
||||
|
||||
// Detect when last_indexed_at changes (indexing completed) via Electric SQL
|
||||
useEffect(() => {
|
||||
const previousValues = previousLastIndexedAtRef.current;
|
||||
const newIndexingIds = new Set(indexingConnectorIds);
|
||||
let hasChanges = false;
|
||||
|
||||
for (const connector of connectors) {
|
||||
const previousValue = previousValues.get(connector.id);
|
||||
const currentValue = connector.last_indexed_at;
|
||||
|
||||
// If last_indexed_at changed and connector was in indexing state, clear it
|
||||
if (
|
||||
previousValue !== undefined && // We've seen this connector before
|
||||
previousValue !== currentValue && // Value changed
|
||||
indexingConnectorIds.has(connector.id) // It was marked as indexing
|
||||
) {
|
||||
newIndexingIds.delete(connector.id);
|
||||
hasChanges = true;
|
||||
}
|
||||
|
||||
// Update previous value tracking
|
||||
previousValues.set(connector.id, currentValue);
|
||||
}
|
||||
|
||||
if (hasChanges) {
|
||||
setIndexingConnectorIds(newIndexingIds);
|
||||
}
|
||||
}, [connectors, indexingConnectorIds]);
|
||||
|
||||
// Add a connector to the indexing set (called when indexing starts)
|
||||
const startIndexing = useCallback((connectorId: number) => {
|
||||
setIndexingConnectorIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.add(connectorId);
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Remove a connector from the indexing set (called manually if needed)
|
||||
const stopIndexing = useCallback((connectorId: number) => {
|
||||
setIndexingConnectorIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.delete(connectorId);
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Check if a connector is currently indexing
|
||||
const isIndexing = useCallback(
|
||||
(connectorId: number) => indexingConnectorIds.has(connectorId),
|
||||
[indexingConnectorIds]
|
||||
);
|
||||
|
||||
return {
|
||||
indexingConnectorIds,
|
||||
startIndexing,
|
||||
stopIndexing,
|
||||
isIndexing,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,15 +1,17 @@
|
|||
"use client";
|
||||
|
||||
import { differenceInDays, differenceInMinutes, format, isToday, isYesterday } from "date-fns";
|
||||
import { ArrowRight, Cable, Loader2 } from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { FC } from "react";
|
||||
import { useState } from "react";
|
||||
import { getDocumentTypeLabel } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { TabsContent } from "@/components/ui/tabs";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import type { LogActiveTask, LogSummary } from "@/contracts/types/log.types";
|
||||
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { OAUTH_CONNECTORS } from "../constants/connector-constants";
|
||||
import { getDocumentCountForConnector } from "../utils/connector-document-mapping";
|
||||
|
|
@ -21,20 +23,26 @@ interface ActiveConnectorsTabProps {
|
|||
activeDocumentTypes: Array<[string, number]>;
|
||||
connectors: SearchSourceConnector[];
|
||||
indexingConnectorIds: Set<number>;
|
||||
logsSummary: LogSummary | undefined;
|
||||
searchSpaceId: string;
|
||||
onTabChange: (value: string) => void;
|
||||
onManage?: (connector: SearchSourceConnector) => void;
|
||||
onViewAccountsList?: (connectorType: string, connectorTitle: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a connector type is indexable
|
||||
*/
|
||||
function isIndexableConnector(connectorType: string): boolean {
|
||||
const nonIndexableTypes = ["MCP_CONNECTOR"];
|
||||
return !nonIndexableTypes.includes(connectorType);
|
||||
}
|
||||
|
||||
export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
||||
searchQuery,
|
||||
hasSources,
|
||||
activeDocumentTypes,
|
||||
connectors,
|
||||
indexingConnectorIds,
|
||||
logsSummary,
|
||||
searchSpaceId,
|
||||
onTabChange,
|
||||
onManage,
|
||||
|
|
@ -67,32 +75,6 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
return `${m.replace(/\.0$/, "")}M docs`;
|
||||
};
|
||||
|
||||
// Format last indexed date with contextual messages
|
||||
const formatLastIndexedDate = (dateString: string): string => {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const minutesAgo = differenceInMinutes(now, date);
|
||||
const daysAgo = differenceInDays(now, date);
|
||||
|
||||
if (minutesAgo < 1) return "Just now";
|
||||
if (minutesAgo < 60) return `${minutesAgo} ${minutesAgo === 1 ? "minute" : "minutes"} ago`;
|
||||
if (isToday(date)) return `Today at ${format(date, "h:mm a")}`;
|
||||
if (isYesterday(date)) return `Yesterday at ${format(date, "h:mm a")}`;
|
||||
if (daysAgo < 7) return `${daysAgo} ${daysAgo === 1 ? "day" : "days"} ago`;
|
||||
return format(date, "MMM d, yyyy");
|
||||
};
|
||||
|
||||
// Get most recent last indexed date from a list of connectors
|
||||
const getMostRecentLastIndexed = (
|
||||
connectorsList: SearchSourceConnector[]
|
||||
): string | undefined => {
|
||||
return connectorsList.reduce<string | undefined>((latest, c) => {
|
||||
if (!c.last_indexed_at) return latest;
|
||||
if (!latest) return c.last_indexed_at;
|
||||
return new Date(c.last_indexed_at) > new Date(latest) ? c.last_indexed_at : latest;
|
||||
}, undefined);
|
||||
};
|
||||
|
||||
// Document types that should be shown as standalone cards (not from connectors)
|
||||
const standaloneDocumentTypes = ["EXTENSION", "FILE", "NOTE", "YOUTUBE_VIDEO", "CRAWLED_URL"];
|
||||
|
||||
|
|
@ -190,7 +172,6 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
documentTypeCounts
|
||||
);
|
||||
const accountCount = typeConnectors.length;
|
||||
const mostRecentLastIndexed = getMostRecentLastIndexed(typeConnectors);
|
||||
|
||||
const handleManageClick = () => {
|
||||
if (onViewAccountsList) {
|
||||
|
|
@ -204,10 +185,10 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
<div
|
||||
key={`oauth-type-${connectorType}`}
|
||||
className={cn(
|
||||
"relative flex items-center gap-4 p-4 rounded-xl border border-border transition-all",
|
||||
"relative flex items-center gap-4 p-4 rounded-xl transition-all",
|
||||
isAnyIndexing
|
||||
? "bg-primary/5 border-primary/20"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10"
|
||||
? "bg-primary/5 border-0"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10 border border-border"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
|
|
@ -225,22 +206,17 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
{isAnyIndexing ? (
|
||||
<p className="text-[11px] text-primary mt-1 flex items-center gap-1.5">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Indexing...
|
||||
Syncing
|
||||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap">
|
||||
{mostRecentLastIndexed
|
||||
? `Last indexed: ${formatLastIndexedDate(mostRecentLastIndexed)}`
|
||||
: "Never indexed"}
|
||||
<p className="text-[10px] text-muted-foreground mt-1 flex items-center gap-1.5">
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
<p className="text-[10px] text-muted-foreground mt-0.5 flex items-center gap-1.5">
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
|
|
@ -257,22 +233,19 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
{/* Non-OAuth Connectors - Individual Cards */}
|
||||
{filteredNonOAuthConnectors.map((connector) => {
|
||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||
const activeTask = logsSummary?.active_tasks?.find(
|
||||
(task: LogActiveTask) => task.connector_id === connector.id
|
||||
);
|
||||
const documentCount = getDocumentCountForConnector(
|
||||
connector.connector_type,
|
||||
documentTypeCounts
|
||||
);
|
||||
|
||||
const isMCPConnector = connector.connector_type === "MCP_CONNECTOR";
|
||||
return (
|
||||
<div
|
||||
key={`connector-${connector.id}`}
|
||||
className={cn(
|
||||
"flex items-center gap-4 p-4 rounded-xl border border-border transition-all",
|
||||
"flex items-center gap-4 p-4 rounded-xl transition-all",
|
||||
isIndexing
|
||||
? "bg-primary/5 border-primary/20"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10"
|
||||
? "bg-primary/5 border-0"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10 border border-border"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
|
|
@ -286,29 +259,21 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
{getConnectorIcon(connector.connector_type, "size-6")}
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-[14px] font-semibold leading-tight truncate">
|
||||
{connector.name}
|
||||
</p>
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="text-[14px] font-semibold leading-tight">
|
||||
{connector.name}
|
||||
</p>
|
||||
</div>
|
||||
{isIndexing ? (
|
||||
<p className="text-[11px] text-primary mt-1 flex items-center gap-1.5">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Indexing...
|
||||
{activeTask?.message && (
|
||||
<span className="text-muted-foreground truncate max-w-[150px]">
|
||||
• {activeTask.message}
|
||||
</span>
|
||||
)}
|
||||
Syncing
|
||||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap">
|
||||
{connector.last_indexed_at
|
||||
? `Last indexed: ${formatLastIndexedDate(connector.last_indexed_at)}`
|
||||
: "Never indexed"}
|
||||
) : !isMCPConnector ? (
|
||||
<p className="text-[10px] text-muted-foreground mt-1">
|
||||
{formatDocumentCount(documentCount)}
|
||||
</p>
|
||||
)}
|
||||
<p className="text-[10px] text-muted-foreground mt-0.5">
|
||||
{formatDocumentCount(documentCount)}
|
||||
</p>
|
||||
) : null}
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
|
|
@ -362,19 +327,12 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
) : (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-center">
|
||||
<div className="flex h-16 w-16 items-center justify-center rounded-full bg-muted mb-4">
|
||||
<Cable className="size-8 text-muted-foreground/50" />
|
||||
<Cable className="size-8 text-muted-foreground" />
|
||||
</div>
|
||||
<h4 className="text-lg font-semibold">No active sources</h4>
|
||||
<p className="text-sm text-muted-foreground mt-1 max-w-[280px]">
|
||||
Connect your first service to start searching across all your data.
|
||||
</p>
|
||||
<Button
|
||||
variant="link"
|
||||
className="mt-6 text-primary hover:underline"
|
||||
onClick={() => onTabChange("all")}
|
||||
>
|
||||
Browse available connectors
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</TabsContent>
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
"use client";
|
||||
|
||||
import { Plus } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import type { LogActiveTask, LogSummary } from "@/contracts/types/log.types";
|
||||
import { ConnectorCard } from "../components/connector-card";
|
||||
import { CRAWLERS, OAUTH_CONNECTORS, OTHER_CONNECTORS } from "../constants/connector-constants";
|
||||
import { getDocumentCountForConnector } from "../utils/connector-document-mapping";
|
||||
|
|
@ -30,7 +28,6 @@ interface AllConnectorsTabProps {
|
|||
allConnectors: SearchSourceConnector[] | undefined;
|
||||
documentTypeCounts?: Record<string, number>;
|
||||
indexingConnectorIds?: Set<number>;
|
||||
logsSummary?: LogSummary;
|
||||
onConnectOAuth: (connector: (typeof OAUTH_CONNECTORS)[number]) => void;
|
||||
onConnectNonOAuth?: (connectorType: string) => void;
|
||||
onCreateWebcrawler?: () => void;
|
||||
|
|
@ -41,13 +38,11 @@ interface AllConnectorsTabProps {
|
|||
|
||||
export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
||||
searchQuery,
|
||||
searchSpaceId,
|
||||
connectedTypes,
|
||||
connectingId,
|
||||
allConnectors,
|
||||
documentTypeCounts,
|
||||
indexingConnectorIds,
|
||||
logsSummary,
|
||||
onConnectOAuth,
|
||||
onConnectNonOAuth,
|
||||
onCreateWebcrawler,
|
||||
|
|
@ -55,14 +50,6 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
onManage,
|
||||
onViewAccountsList,
|
||||
}) => {
|
||||
// Helper to find active task for a connector
|
||||
const getActiveTaskForConnector = (connectorId: number): LogActiveTask | undefined => {
|
||||
if (!logsSummary?.active_tasks) return undefined;
|
||||
return logsSummary.active_tasks.find(
|
||||
(task: LogActiveTask) => task.connector_id === connectorId
|
||||
);
|
||||
};
|
||||
|
||||
// Filter connectors based on search
|
||||
const filteredOAuth = OAUTH_CONNECTORS.filter(
|
||||
(c) =>
|
||||
|
|
@ -103,6 +90,8 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
)
|
||||
: [];
|
||||
|
||||
const accountCount = typeConnectors.length;
|
||||
|
||||
// Get the most recent last_indexed_at across all accounts
|
||||
const mostRecentLastIndexed = typeConnectors.reduce<string | undefined>(
|
||||
(latest, c) => {
|
||||
|
|
@ -123,11 +112,6 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
// Check if any account is currently indexing
|
||||
const isIndexing = typeConnectors.some((c) => indexingConnectorIds?.has(c.id));
|
||||
|
||||
// Get active task from any indexing account
|
||||
const activeTask = typeConnectors
|
||||
.map((c) => getActiveTaskForConnector(c.id))
|
||||
.find((task) => task !== undefined);
|
||||
|
||||
return (
|
||||
<ConnectorCard
|
||||
key={connector.id}
|
||||
|
|
@ -138,10 +122,8 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
isConnected={isConnected}
|
||||
isConnecting={isConnecting}
|
||||
documentCount={documentCount}
|
||||
accountCount={typeConnectors.length}
|
||||
lastIndexedAt={mostRecentLastIndexed}
|
||||
accountCount={accountCount}
|
||||
isIndexing={isIndexing}
|
||||
activeTask={activeTask}
|
||||
onConnect={() => onConnectOAuth(connector)}
|
||||
onManage={
|
||||
isConnected && onViewAccountsList
|
||||
|
|
@ -179,9 +161,16 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
documentTypeCounts
|
||||
);
|
||||
const isIndexing = actualConnector && indexingConnectorIds?.has(actualConnector.id);
|
||||
const activeTask = actualConnector
|
||||
? getActiveTaskForConnector(actualConnector.id)
|
||||
: undefined;
|
||||
|
||||
// For MCP connectors, count total MCP connectors instead of document count
|
||||
const isMCP = connector.connectorType === EnumConnectorName.MCP_CONNECTOR;
|
||||
const mcpConnectorCount =
|
||||
isMCP && allConnectors
|
||||
? allConnectors.filter(
|
||||
(c: SearchSourceConnector) =>
|
||||
c.connector_type === EnumConnectorName.MCP_CONNECTOR
|
||||
).length
|
||||
: undefined;
|
||||
|
||||
const handleConnect = onConnectNonOAuth
|
||||
? () => onConnectNonOAuth(connector.connectorType)
|
||||
|
|
@ -197,9 +186,8 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
isConnected={isConnected}
|
||||
isConnecting={isConnecting}
|
||||
documentCount={documentCount}
|
||||
lastIndexedAt={actualConnector?.last_indexed_at}
|
||||
connectorCount={mcpConnectorCount}
|
||||
isIndexing={isIndexing}
|
||||
activeTask={activeTask}
|
||||
onConnect={handleConnect}
|
||||
onManage={
|
||||
actualConnector && onManage ? () => onManage(actualConnector) : undefined
|
||||
|
|
@ -240,9 +228,6 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
? getDocumentCountForConnector(crawler.connectorType, documentTypeCounts)
|
||||
: undefined;
|
||||
const isIndexing = actualConnector && indexingConnectorIds?.has(actualConnector.id);
|
||||
const activeTask = actualConnector
|
||||
? getActiveTaskForConnector(actualConnector.id)
|
||||
: undefined;
|
||||
|
||||
const handleConnect =
|
||||
isYouTube && onCreateYouTubeCrawler
|
||||
|
|
@ -267,9 +252,7 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
isConnected={isConnected}
|
||||
isConnecting={isConnecting}
|
||||
documentCount={documentCount}
|
||||
lastIndexedAt={actualConnector?.last_indexed_at}
|
||||
isIndexing={isIndexing}
|
||||
activeTask={activeTask}
|
||||
onConnect={handleConnect}
|
||||
onManage={
|
||||
actualConnector && onManage ? () => onManage(actualConnector) : undefined
|
||||
|
|
|
|||
|
|
@ -0,0 +1,270 @@
|
|||
/**
|
||||
* MCP Configuration Validator Utility
|
||||
*
|
||||
* Shared validation and parsing logic for MCP (Model Context Protocol) server configurations.
|
||||
*
|
||||
* Features:
|
||||
* - Zod schema validation for runtime type safety
|
||||
* - Configuration caching to avoid repeated parsing (5-minute TTL)
|
||||
* - Standardized error messages
|
||||
* - Connection testing utilities
|
||||
*
|
||||
* Usage:
|
||||
* ```typescript
|
||||
* // Parse and validate config
|
||||
* const result = parseMCPConfig(jsonString);
|
||||
* if (result.config) {
|
||||
* // Valid config
|
||||
* } else {
|
||||
* // Show result.error to user
|
||||
* }
|
||||
*
|
||||
* // Test connection
|
||||
* const testResult = await testMCPConnection(config);
|
||||
* if (testResult.status === "success") {
|
||||
* console.log(`Found ${testResult.tools.length} tools`);
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @module mcp-config-validator
|
||||
*/
|
||||
|
||||
import { z } from "zod";
|
||||
import type { MCPServerConfig, MCPToolDefinition } from "@/contracts/types/mcp.types";
|
||||
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
|
||||
|
||||
/**
|
||||
* Zod schema for MCP server configuration
|
||||
* Supports both stdio (local process) and HTTP (remote server) transports
|
||||
*
|
||||
* Exported for advanced use cases (e.g., form builders)
|
||||
*/
|
||||
const StdioConfigSchema = z.object({
|
||||
name: z.string().optional(),
|
||||
command: z.string().min(1, "Command cannot be empty"),
|
||||
args: z.array(z.string()).optional().default([]),
|
||||
env: z.record(z.string(), z.string()).optional().default({}),
|
||||
transport: z.enum(["stdio"]).optional().default("stdio"),
|
||||
});
|
||||
|
||||
const HttpConfigSchema = z.object({
|
||||
name: z.string().optional(),
|
||||
url: z.string().url("URL must be a valid URL"),
|
||||
headers: z.record(z.string(), z.string()).optional().default({}),
|
||||
transport: z.enum(["streamable-http", "http", "sse"]),
|
||||
});
|
||||
|
||||
export const MCPServerConfigSchema = z.union([StdioConfigSchema, HttpConfigSchema]);
|
||||
|
||||
/**
|
||||
* Shared MCP configuration validation result
|
||||
*/
|
||||
export interface MCPConfigValidationResult {
|
||||
config: MCPServerConfig | null;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared MCP connection test result
|
||||
*/
|
||||
export interface MCPConnectionTestResult {
|
||||
status: "success" | "error";
|
||||
message: string;
|
||||
tools: MCPToolDefinition[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Cache for parsed configurations to avoid re-parsing
|
||||
* Key: JSON string, Value: { config, timestamp }
|
||||
*/
|
||||
const configCache = new Map<string, { config: MCPServerConfig; timestamp: number }>();
|
||||
const CACHE_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
/**
|
||||
* Clear expired entries from config cache
|
||||
*/
|
||||
const clearExpiredCache = () => {
|
||||
const now = Date.now();
|
||||
for (const [key, value] of configCache.entries()) {
|
||||
if (now - value.timestamp > CACHE_TTL) {
|
||||
configCache.delete(key);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse and validate MCP server configuration from JSON string
|
||||
* Uses Zod for schema validation and caching to avoid re-parsing
|
||||
* @param configJson - JSON string containing MCP server configuration
|
||||
* @returns Validation result with parsed config or error message
|
||||
*/
|
||||
export const parseMCPConfig = (configJson: string): MCPConfigValidationResult => {
|
||||
// Check cache first
|
||||
const cached = configCache.get(configJson);
|
||||
if (cached && Date.now() - cached.timestamp < CACHE_TTL) {
|
||||
console.log("[MCP Validator] ✅ Using cached config");
|
||||
return { config: cached.config, error: null };
|
||||
}
|
||||
|
||||
console.log("[MCP Validator] 🔍 Parsing new config...");
|
||||
|
||||
// Clean up expired cache entries periodically
|
||||
if (configCache.size > 100) {
|
||||
clearExpiredCache();
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(configJson);
|
||||
|
||||
// Validate that it's an object, not an array
|
||||
if (Array.isArray(parsed)) {
|
||||
console.error("[MCP Validator] ❌ Error: Config is an array, expected object");
|
||||
return {
|
||||
config: null,
|
||||
error: "Please provide a single server configuration object, not an array",
|
||||
};
|
||||
}
|
||||
|
||||
// Use Zod schema validation for robust type checking
|
||||
const result = MCPServerConfigSchema.safeParse(parsed);
|
||||
|
||||
if (!result.success) {
|
||||
// Format Zod validation errors for user-friendly display
|
||||
const firstError = result.error.issues[0];
|
||||
const fieldPath = firstError.path.join(".");
|
||||
|
||||
// Clean up error message - remove technical Zod jargon
|
||||
let errorMsg = firstError.message;
|
||||
|
||||
// Replace technical error messages with user-friendly ones
|
||||
if (errorMsg.includes("expected string, received undefined")) {
|
||||
errorMsg = fieldPath ? `The '${fieldPath}' field is required` : "This field is required";
|
||||
} else if (errorMsg.includes("Invalid input")) {
|
||||
errorMsg = fieldPath ? `The '${fieldPath}' field has an invalid value` : "Invalid value";
|
||||
} else if (fieldPath && !errorMsg.toLowerCase().includes(fieldPath.toLowerCase())) {
|
||||
// If error message doesn't mention the field name, prepend it
|
||||
errorMsg = `The '${fieldPath}' field: ${errorMsg}`;
|
||||
}
|
||||
|
||||
console.error("[MCP Validator] ❌ Validation error:", errorMsg);
|
||||
console.error("[MCP Validator] Full Zod errors:", result.error.issues);
|
||||
|
||||
return {
|
||||
config: null,
|
||||
error: errorMsg,
|
||||
};
|
||||
}
|
||||
|
||||
// Build config based on transport type
|
||||
const config: MCPServerConfig =
|
||||
result.data.transport === "stdio" || !result.data.transport
|
||||
? {
|
||||
command: (result.data as z.infer<typeof StdioConfigSchema>).command,
|
||||
args: (result.data as z.infer<typeof StdioConfigSchema>).args,
|
||||
env: (result.data as z.infer<typeof StdioConfigSchema>).env,
|
||||
transport: "stdio" as const,
|
||||
}
|
||||
: {
|
||||
url: (result.data as z.infer<typeof HttpConfigSchema>).url,
|
||||
headers: (result.data as z.infer<typeof HttpConfigSchema>).headers,
|
||||
transport: result.data.transport as "streamable-http" | "http" | "sse",
|
||||
};
|
||||
|
||||
// Cache the successfully parsed config
|
||||
configCache.set(configJson, {
|
||||
config,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
console.log("[MCP Validator] ✅ Config parsed successfully:", config);
|
||||
|
||||
return {
|
||||
config,
|
||||
error: null,
|
||||
};
|
||||
} catch (error) {
|
||||
const errorMsg = error instanceof Error ? error.message : "Invalid JSON";
|
||||
console.error("[MCP Validator] ❌ JSON parse error:", errorMsg);
|
||||
return {
|
||||
config: null,
|
||||
error: errorMsg,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Test connection to MCP server
|
||||
* @param serverConfig - MCP server configuration to test
|
||||
* @returns Connection test result with status, message, and available tools
|
||||
*/
|
||||
export const testMCPConnection = async (
|
||||
serverConfig: MCPServerConfig
|
||||
): Promise<MCPConnectionTestResult> => {
|
||||
try {
|
||||
const result = await connectorsApiService.testMCPConnection(serverConfig);
|
||||
|
||||
if (result.status === "success") {
|
||||
return {
|
||||
status: "success",
|
||||
message: `Successfully connected. Found ${result.tools.length} tool${result.tools.length !== 1 ? "s" : ""}.`,
|
||||
tools: result.tools,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
status: "error",
|
||||
message: result.message || "Failed to connect",
|
||||
tools: [],
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
status: "error",
|
||||
message: error instanceof Error ? error.message : "Failed to connect",
|
||||
tools: [],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Extract server name from MCP config JSON with caching
|
||||
* @param configJson - JSON string containing MCP server configuration
|
||||
* @returns Server name if found, otherwise default name
|
||||
*/
|
||||
export const extractServerName = (configJson: string): string => {
|
||||
try {
|
||||
const parsed = JSON.parse(configJson);
|
||||
|
||||
// Use Zod to validate and extract name field safely
|
||||
const nameSchema = z.object({ name: z.string().optional() });
|
||||
const result = nameSchema.safeParse(parsed);
|
||||
|
||||
if (result.success && result.data.name) {
|
||||
return result.data.name;
|
||||
}
|
||||
} catch {
|
||||
// Return default if parsing fails
|
||||
}
|
||||
return "MCP Server";
|
||||
};
|
||||
|
||||
/**
|
||||
* Clear the configuration cache
|
||||
* Useful for testing or when memory management is needed
|
||||
*/
|
||||
export const clearConfigCache = () => {
|
||||
configCache.clear();
|
||||
};
|
||||
|
||||
/**
|
||||
* Get cache statistics for monitoring/debugging
|
||||
*/
|
||||
export const getConfigCacheStats = () => {
|
||||
return {
|
||||
size: configCache.size,
|
||||
entries: Array.from(configCache.entries()).map(([key, value]) => ({
|
||||
configPreview: key.substring(0, 50) + (key.length > 50 ? "..." : ""),
|
||||
timestamp: new Date(value.timestamp).toISOString(),
|
||||
age: Date.now() - value.timestamp,
|
||||
})),
|
||||
};
|
||||
};
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
"use client";
|
||||
|
||||
import { differenceInDays, differenceInMinutes, format, isToday, isYesterday } from "date-fns";
|
||||
import { ArrowLeft, Loader2, Plus } from "lucide-react";
|
||||
import { ArrowLeft, Loader2, Plus, Server } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import type { LogActiveTask, LogSummary } from "@/contracts/types/log.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
||||
|
|
@ -16,11 +16,19 @@ interface ConnectorAccountsListViewProps {
|
|||
connectorTitle: string;
|
||||
connectors: SearchSourceConnector[];
|
||||
indexingConnectorIds: Set<number>;
|
||||
logsSummary: LogSummary | undefined;
|
||||
onBack: () => void;
|
||||
onManage: (connector: SearchSourceConnector) => void;
|
||||
onAddAccount: () => void;
|
||||
isConnecting?: boolean;
|
||||
addButtonText?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a connector type is indexable
|
||||
*/
|
||||
function isIndexableConnector(connectorType: string): boolean {
|
||||
const nonIndexableTypes = ["MCP_CONNECTOR"];
|
||||
return !nonIndexableTypes.includes(connectorType);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -60,11 +68,11 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
connectorTitle,
|
||||
connectors,
|
||||
indexingConnectorIds,
|
||||
logsSummary,
|
||||
onBack,
|
||||
onManage,
|
||||
onAddAccount,
|
||||
isConnecting = false,
|
||||
addButtonText,
|
||||
}) => {
|
||||
// Get connector status
|
||||
const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus();
|
||||
|
|
@ -75,6 +83,22 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
// Filter connectors to only show those of this type
|
||||
const typeConnectors = connectors.filter((c) => c.connector_type === connectorType);
|
||||
|
||||
// Determine button text - default to "Add Account" unless specified
|
||||
const buttonText =
|
||||
addButtonText ||
|
||||
(connectorType === EnumConnectorName.MCP_CONNECTOR ? "Add New MCP Server" : "Add Account");
|
||||
const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR;
|
||||
|
||||
// Helper to get display name for connector (handles MCP server name extraction)
|
||||
const getDisplayName = (connector: SearchSourceConnector): string => {
|
||||
if (isMCP) {
|
||||
// For MCP, extract server name from config if available
|
||||
const serverName = connector.config?.server_config?.name || connector.name;
|
||||
return serverName;
|
||||
}
|
||||
return getConnectorDisplayName(connector.name);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Header */}
|
||||
|
|
@ -110,22 +134,22 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
onClick={onAddAccount}
|
||||
disabled={isConnecting || !isEnabled}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 sm:gap-2 px-2 sm:px-3 py-1.5 sm:py-2 rounded-lg border-2 border-dashed text-left transition-all duration-200 shrink-0 self-center sm:self-auto sm:w-auto",
|
||||
"flex items-center justify-center gap-1.5 h-8 px-3 rounded-md border-2 border-dashed text-xs sm:text-sm transition-all duration-200 shrink-0 w-full sm:w-auto",
|
||||
!isEnabled
|
||||
? "border-border/30 opacity-50 cursor-not-allowed"
|
||||
: "border-primary/50 hover:bg-primary/5",
|
||||
: "border-slate-400/20 dark:border-white/20 hover:bg-primary/5",
|
||||
isConnecting && "opacity-50 cursor-not-allowed"
|
||||
)}
|
||||
>
|
||||
<div className="flex h-5 w-5 sm:h-6 sm:w-6 items-center justify-center rounded-md bg-primary/10 shrink-0">
|
||||
<div className="flex h-5 w-5 items-center justify-center rounded-md bg-primary/10 shrink-0">
|
||||
{isConnecting ? (
|
||||
<Loader2 className="size-3 sm:size-3.5 animate-spin text-primary" />
|
||||
<Loader2 className="size-3 animate-spin text-primary" />
|
||||
) : (
|
||||
<Plus className="size-3 sm:size-3.5 text-primary" />
|
||||
<Plus className="size-3 text-primary" />
|
||||
)}
|
||||
</div>
|
||||
<span className="text-[11px] sm:text-[12px] font-medium">
|
||||
{isConnecting ? "Connecting..." : "Add Account"}
|
||||
<span className="text-xs sm:text-sm font-medium">
|
||||
{isConnecting ? "Connecting" : buttonText}
|
||||
</span>
|
||||
</button>
|
||||
</div>
|
||||
|
|
@ -134,67 +158,81 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
{/* Content */}
|
||||
<div className="flex-1 overflow-y-auto px-6 sm:px-12 pt-0 sm:pt-6 pb-6 sm:pb-8">
|
||||
{/* Connected Accounts Grid */}
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||
{typeConnectors.map((connector) => {
|
||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||
const activeTask = logsSummary?.active_tasks?.find(
|
||||
(task: LogActiveTask) => task.connector_id === connector.id
|
||||
);
|
||||
{typeConnectors.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center py-12 text-center">
|
||||
<div className="h-16 w-16 rounded-full bg-slate-400/5 dark:bg-white/5 flex items-center justify-center mb-4">
|
||||
{isMCP ? (
|
||||
<Server className="h-8 w-8 text-muted-foreground" />
|
||||
) : (
|
||||
getConnectorIcon(connectorType, "size-8")
|
||||
)}
|
||||
</div>
|
||||
<h3 className="text-sm font-medium mb-1">
|
||||
{isMCP ? "No MCP Servers" : `No ${connectorTitle} Accounts`}
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground max-w-[280px]">
|
||||
{isMCP
|
||||
? "Get started by adding your first Model Context Protocol server"
|
||||
: `Get started by connecting your first ${connectorTitle} account`}
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||
{typeConnectors.map((connector) => {
|
||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={connector.id}
|
||||
className={cn(
|
||||
"flex items-center gap-4 p-4 rounded-xl border border-border transition-all",
|
||||
isIndexing
|
||||
? "bg-primary/5 border-primary/20"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10"
|
||||
)}
|
||||
>
|
||||
return (
|
||||
<div
|
||||
key={connector.id}
|
||||
className={cn(
|
||||
"flex h-12 w-12 items-center justify-center rounded-lg border shrink-0",
|
||||
"flex items-center gap-4 p-4 rounded-xl transition-all",
|
||||
isIndexing
|
||||
? "bg-primary/10 border-primary/20"
|
||||
: "bg-slate-400/5 dark:bg-white/5 border-slate-400/5 dark:border-white/5"
|
||||
? "bg-primary/5 border-0"
|
||||
: "bg-slate-400/5 dark:bg-white/5 hover:bg-slate-400/10 dark:hover:bg-white/10 border border-border"
|
||||
)}
|
||||
>
|
||||
{getConnectorIcon(connector.connector_type, "size-6")}
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-[14px] font-semibold leading-tight truncate">
|
||||
{getConnectorDisplayName(connector.name)}
|
||||
</p>
|
||||
{isIndexing ? (
|
||||
<p className="text-[11px] text-primary mt-1 flex items-center gap-1.5">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Indexing...
|
||||
{activeTask?.message && (
|
||||
<span className="text-muted-foreground truncate max-w-[100px]">
|
||||
• {activeTask.message}
|
||||
</span>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-12 w-12 items-center justify-center rounded-lg border shrink-0",
|
||||
isIndexing
|
||||
? "bg-primary/10 border-primary/20"
|
||||
: "bg-slate-400/5 dark:bg-white/5 border-slate-400/5 dark:border-white/5"
|
||||
)}
|
||||
>
|
||||
{getConnectorIcon(connector.connector_type, "size-6")}
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-[14px] font-semibold leading-tight truncate">
|
||||
{getDisplayName(connector)}
|
||||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap truncate">
|
||||
{connector.last_indexed_at
|
||||
? `Last indexed: ${formatLastIndexedDate(connector.last_indexed_at)}`
|
||||
: "Never indexed"}
|
||||
</p>
|
||||
)}
|
||||
{isIndexing ? (
|
||||
<p className="text-[11px] text-primary mt-1 flex items-center gap-1.5">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Syncing
|
||||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap truncate">
|
||||
{isIndexableConnector(connector.connector_type)
|
||||
? connector.last_indexed_at
|
||||
? `Last indexed: ${formatLastIndexedDate(connector.last_indexed_at)}`
|
||||
: "Never indexed"
|
||||
: "Active"}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
||||
onClick={() => onManage(connector)}
|
||||
>
|
||||
Manage
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80 shrink-0"
|
||||
onClick={() => onManage(connector)}
|
||||
>
|
||||
Manage
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Upload } from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import {
|
||||
createContext,
|
||||
type FC,
|
||||
|
|
@ -85,13 +84,11 @@ const DocumentUploadPopupContent: FC<{
|
|||
onOpenChange: (open: boolean) => void;
|
||||
}> = ({ isOpen, onOpenChange }) => {
|
||||
const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
|
||||
const router = useRouter();
|
||||
|
||||
if (!searchSpaceId) return null;
|
||||
|
||||
const handleSuccess = () => {
|
||||
onOpenChange(false);
|
||||
router.push(`/dashboard/${searchSpaceId}/documents`);
|
||||
};
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import {
|
|||
useState,
|
||||
} from "react";
|
||||
import ReactDOMServer from "react-dom/server";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
|
|
@ -166,12 +167,19 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
chip.setAttribute(CHIP_DOCTYPE_ATTR, doc.document_type ?? "UNKNOWN");
|
||||
chip.contentEditable = "false";
|
||||
chip.className =
|
||||
"inline-flex items-center gap-0.5 mx-0.5 pl-1 pr-0.5 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary border border-primary/10 select-none";
|
||||
"inline-flex items-center gap-1 mx-0.5 pl-1 pr-0.5 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none";
|
||||
chip.style.userSelect = "none";
|
||||
chip.style.verticalAlign = "baseline";
|
||||
|
||||
// Add document type icon
|
||||
const iconSpan = document.createElement("span");
|
||||
iconSpan.className = "shrink-0 flex items-center text-muted-foreground";
|
||||
iconSpan.innerHTML = ReactDOMServer.renderToString(
|
||||
getConnectorIcon(doc.document_type ?? "UNKNOWN", "h-3 w-3")
|
||||
);
|
||||
|
||||
const titleSpan = document.createElement("span");
|
||||
titleSpan.className = "max-w-[80px] truncate";
|
||||
titleSpan.className = "max-w-[120px] truncate";
|
||||
titleSpan.textContent = doc.title;
|
||||
titleSpan.title = doc.title;
|
||||
|
||||
|
|
@ -197,6 +205,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
focusAtEnd();
|
||||
};
|
||||
|
||||
chip.appendChild(iconSpan);
|
||||
chip.appendChild(titleSpan);
|
||||
chip.appendChild(removeBtn);
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,10 @@ export const ThinkingStepsDisplay: FC<{ steps: ThinkingStep[]; isThreadRunning?:
|
|||
{/* Step dot - on top of line */}
|
||||
<div className="relative z-10 mt-[7px] flex shrink-0 items-center justify-center">
|
||||
{effectiveStatus === "in_progress" ? (
|
||||
<span className="size-2 rounded-full bg-muted-foreground/30" />
|
||||
<span className="relative flex size-2">
|
||||
<span className="absolute inline-flex size-full animate-ping rounded-full bg-primary/60" />
|
||||
<span className="relative inline-flex size-2 rounded-full bg-primary" />
|
||||
</span>
|
||||
) : (
|
||||
<span className="size-2 rounded-full bg-muted-foreground/30" />
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,71 +0,0 @@
|
|||
import { useAtomValue } from "jotai";
|
||||
import type { FC } from "react";
|
||||
import { useMemo } from "react";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { Composer } from "@/components/assistant-ui/composer";
|
||||
|
||||
const getTimeBasedGreeting = (userEmail?: string): string => {
|
||||
const hour = new Date().getHours();
|
||||
|
||||
// Extract first name from email if available
|
||||
const firstName = userEmail
|
||||
? userEmail.split("@")[0].split(".")[0].charAt(0).toUpperCase() +
|
||||
userEmail.split("@")[0].split(".")[0].slice(1)
|
||||
: null;
|
||||
|
||||
// Array of greeting variations for each time period
|
||||
const morningGreetings = ["Good morning", "Fresh start today", "Morning", "Hey there"];
|
||||
|
||||
const afternoonGreetings = ["Good afternoon", "Afternoon", "Hey there", "Hi there"];
|
||||
|
||||
const eveningGreetings = ["Good evening", "Evening", "Hey there", "Hi there"];
|
||||
|
||||
const nightGreetings = ["Good night", "Evening", "Hey there", "Winding down"];
|
||||
|
||||
const lateNightGreetings = ["Still up", "Night owl mode", "Up past bedtime", "Hi there"];
|
||||
|
||||
// Select a random greeting based on time
|
||||
let greeting: string;
|
||||
if (hour < 5) {
|
||||
// Late night: midnight to 5 AM
|
||||
greeting = lateNightGreetings[Math.floor(Math.random() * lateNightGreetings.length)];
|
||||
} else if (hour < 12) {
|
||||
greeting = morningGreetings[Math.floor(Math.random() * morningGreetings.length)];
|
||||
} else if (hour < 18) {
|
||||
greeting = afternoonGreetings[Math.floor(Math.random() * afternoonGreetings.length)];
|
||||
} else if (hour < 22) {
|
||||
greeting = eveningGreetings[Math.floor(Math.random() * eveningGreetings.length)];
|
||||
} else {
|
||||
// Night: 10 PM to midnight
|
||||
greeting = nightGreetings[Math.floor(Math.random() * nightGreetings.length)];
|
||||
}
|
||||
|
||||
// Add personalization with first name if available
|
||||
if (firstName) {
|
||||
return `${greeting}, ${firstName}!`;
|
||||
}
|
||||
|
||||
return `${greeting}!`;
|
||||
};
|
||||
|
||||
export const ThreadWelcome: FC = () => {
|
||||
const { data: user } = useAtomValue(currentUserAtom);
|
||||
|
||||
// Memoize greeting so it doesn't change on re-renders (only on user change)
|
||||
const greeting = useMemo(() => getTimeBasedGreeting(user?.email), [user?.email]);
|
||||
|
||||
return (
|
||||
<div className="aui-thread-welcome-root mx-auto flex w-full max-w-(--thread-max-width) grow flex-col items-center px-4 relative">
|
||||
{/* Greeting positioned above the composer - fixed position */}
|
||||
<div className="aui-thread-welcome-message absolute bottom-[calc(50%+5rem)] left-0 right-0 flex flex-col items-center text-center">
|
||||
<h1 className="aui-thread-welcome-message-inner fade-in slide-in-from-bottom-2 animate-in text-3xl md:text-5xl delay-100 duration-500 ease-out fill-mode-both">
|
||||
{greeting}
|
||||
</h1>
|
||||
</div>
|
||||
{/* Composer - top edge fixed, expands downward only */}
|
||||
<div className="fade-in slide-in-from-bottom-3 animate-in delay-200 duration-500 ease-out fill-mode-both w-full flex items-start justify-center absolute top-[calc(50%-3.5rem)] left-0 right-0">
|
||||
<Composer />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -26,6 +26,7 @@ import {
|
|||
import { useParams } from "next/navigation";
|
||||
import { type FC, useCallback, useContext, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { showCommentsGutterAtom } from "@/atoms/chat/current-thread.atom";
|
||||
import {
|
||||
mentionedDocumentIdsAtom,
|
||||
mentionedDocumentsAtom,
|
||||
|
|
@ -36,6 +37,7 @@ import {
|
|||
newLLMConfigsAtom,
|
||||
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ComposerAddAttachment, ComposerAttachments } from "@/components/assistant-ui/attachment";
|
||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||
import {
|
||||
|
|
@ -59,57 +61,63 @@ import { Button } from "@/components/ui/button";
|
|||
import type { Document } from "@/contracts/types/document.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/**
|
||||
* Props for the Thread component
|
||||
*/
|
||||
interface ThreadProps {
|
||||
messageThinkingSteps?: Map<string, ThinkingStep[]>;
|
||||
/** Optional header component to render at the top of the viewport (sticky) */
|
||||
header?: React.ReactNode;
|
||||
}
|
||||
|
||||
export const Thread: FC<ThreadProps> = ({ messageThinkingSteps = new Map(), header }) => {
|
||||
return (
|
||||
<ThinkingStepsContext.Provider value={messageThinkingSteps}>
|
||||
<ThreadPrimitive.Root
|
||||
className="aui-root aui-thread-root @container flex h-full min-h-0 flex-col bg-background"
|
||||
style={{
|
||||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
||||
>
|
||||
{/* Optional sticky header for model selector etc. */}
|
||||
{header && <div className="sticky top-0 z-10 mb-4">{header}</div>}
|
||||
|
||||
<AssistantIf condition={({ thread }) => thread.isEmpty}>
|
||||
<ThreadWelcome />
|
||||
</AssistantIf>
|
||||
|
||||
<ThreadPrimitive.Messages
|
||||
components={{
|
||||
UserMessage,
|
||||
EditComposer,
|
||||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter className="aui-thread-viewport-footer sticky bottom-0 mx-auto mt-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-background pb-4 md:pb-6">
|
||||
<ThreadScrollToBottom />
|
||||
<AssistantIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
|
||||
<Composer />
|
||||
</div>
|
||||
</AssistantIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ThreadPrimitive.Root>
|
||||
<ThreadContent header={header} />
|
||||
</ThinkingStepsContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
const ThreadContent: FC<{ header?: React.ReactNode }> = ({ header }) => {
|
||||
const showGutter = useAtomValue(showCommentsGutterAtom);
|
||||
|
||||
return (
|
||||
<ThreadPrimitive.Root
|
||||
className="aui-root aui-thread-root @container flex h-full min-h-0 flex-col bg-background"
|
||||
style={{
|
||||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className={cn(
|
||||
"aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 transition-[padding] duration-300 ease-out",
|
||||
showGutter && "lg:pr-30"
|
||||
)}
|
||||
>
|
||||
{header && <div className="sticky top-0 z-10 mb-4">{header}</div>}
|
||||
|
||||
<AssistantIf condition={({ thread }) => thread.isEmpty}>
|
||||
<ThreadWelcome />
|
||||
</AssistantIf>
|
||||
|
||||
<ThreadPrimitive.Messages
|
||||
components={{
|
||||
UserMessage,
|
||||
EditComposer,
|
||||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter className="aui-thread-viewport-footer sticky bottom-0 z-20 mx-auto mt-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-background pb-4 md:pb-6">
|
||||
<ThreadScrollToBottom />
|
||||
<AssistantIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
|
||||
<Composer />
|
||||
</div>
|
||||
</AssistantIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ThreadPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
|
|
@ -124,14 +132,23 @@ const ThreadScrollToBottom: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const getTimeBasedGreeting = (userEmail?: string): string => {
|
||||
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
|
||||
const hour = new Date().getHours();
|
||||
|
||||
// Extract first name from email if available
|
||||
const firstName = userEmail
|
||||
? userEmail.split("@")[0].split(".")[0].charAt(0).toUpperCase() +
|
||||
userEmail.split("@")[0].split(".")[0].slice(1)
|
||||
: null;
|
||||
// Extract first name: prefer display_name, fall back to email extraction
|
||||
let firstName: string | null = null;
|
||||
|
||||
if (user?.display_name?.trim()) {
|
||||
// Use display_name if available and not empty
|
||||
// Extract first name from display_name (take first word)
|
||||
const nameParts = user.display_name.trim().split(/\s+/);
|
||||
firstName = nameParts[0].charAt(0).toUpperCase() + nameParts[0].slice(1).toLowerCase();
|
||||
} else if (user?.email) {
|
||||
// Fall back to email extraction if display_name is not available
|
||||
firstName =
|
||||
user.email.split("@")[0].split(".")[0].charAt(0).toUpperCase() +
|
||||
user.email.split("@")[0].split(".")[0].slice(1);
|
||||
}
|
||||
|
||||
// Array of greeting variations for each time period
|
||||
const morningGreetings = ["Good morning", "Fresh start today", "Morning", "Hey there"];
|
||||
|
|
@ -172,7 +189,7 @@ const ThreadWelcome: FC = () => {
|
|||
const { data: user } = useAtomValue(currentUserAtom);
|
||||
|
||||
// Memoize greeting so it doesn't change on re-renders (only on user change)
|
||||
const greeting = useMemo(() => getTimeBasedGreeting(user?.email), [user?.email]);
|
||||
const greeting = useMemo(() => getTimeBasedGreeting(user), [user]);
|
||||
|
||||
return (
|
||||
<div className="aui-thread-welcome-root mx-auto flex w-full max-w-(--thread-max-width) grow flex-col items-center px-4 relative">
|
||||
|
|
@ -191,7 +208,7 @@ const ThreadWelcome: FC = () => {
|
|||
};
|
||||
|
||||
const Composer: FC = () => {
|
||||
// ---- State for document mentions (using atoms to persist across remounts) ----
|
||||
// Document mention state (atoms persist across component remounts)
|
||||
const [mentionedDocuments, setMentionedDocuments] = useAtom(mentionedDocumentsAtom);
|
||||
const [showDocumentPopover, setShowDocumentPopover] = useState(false);
|
||||
const [mentionQuery, setMentionQuery] = useState("");
|
||||
|
|
@ -203,16 +220,12 @@ const Composer: FC = () => {
|
|||
const composerRuntime = useComposerRuntime();
|
||||
const hasAutoFocusedRef = useRef(false);
|
||||
|
||||
// Check if thread is empty (new chat)
|
||||
const isThreadEmpty = useAssistantState(({ thread }) => thread.isEmpty);
|
||||
|
||||
// Check if thread is currently running (streaming response)
|
||||
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
|
||||
|
||||
// Auto-focus editor when on new chat page
|
||||
// Auto-focus editor on new chat page after mount
|
||||
useEffect(() => {
|
||||
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {
|
||||
// Small delay to ensure the editor is fully mounted
|
||||
const timeoutId = setTimeout(() => {
|
||||
editorRef.current?.focus();
|
||||
hasAutoFocusedRef.current = true;
|
||||
|
|
@ -221,7 +234,7 @@ const Composer: FC = () => {
|
|||
}
|
||||
}, [isThreadEmpty]);
|
||||
|
||||
// Sync mentioned document IDs to atom for use in chat request
|
||||
// Sync mentioned document IDs to atom for inclusion in chat request payload
|
||||
useEffect(() => {
|
||||
setMentionedDocumentIds({
|
||||
surfsense_doc_ids: mentionedDocuments
|
||||
|
|
@ -233,7 +246,7 @@ const Composer: FC = () => {
|
|||
});
|
||||
}, [mentionedDocuments, setMentionedDocumentIds]);
|
||||
|
||||
// Handle text change from inline editor - sync with assistant-ui composer
|
||||
// Sync editor text with assistant-ui composer runtime
|
||||
const handleEditorChange = useCallback(
|
||||
(text: string) => {
|
||||
composerRuntime.setText(text);
|
||||
|
|
@ -241,13 +254,13 @@ const Composer: FC = () => {
|
|||
[composerRuntime]
|
||||
);
|
||||
|
||||
// Handle @ mention trigger from inline editor
|
||||
// Open document picker when @ mention is triggered
|
||||
const handleMentionTrigger = useCallback((query: string) => {
|
||||
setShowDocumentPopover(true);
|
||||
setMentionQuery(query);
|
||||
}, []);
|
||||
|
||||
// Handle mention close
|
||||
// Close document picker and reset query
|
||||
const handleMentionClose = useCallback(() => {
|
||||
if (showDocumentPopover) {
|
||||
setShowDocumentPopover(false);
|
||||
|
|
@ -255,7 +268,7 @@ const Composer: FC = () => {
|
|||
}
|
||||
}, [showDocumentPopover]);
|
||||
|
||||
// Handle keyboard navigation when popover is open
|
||||
// Keyboard navigation for document picker (arrow keys, Enter, Escape)
|
||||
const handleKeyDown = useCallback(
|
||||
(e: React.KeyboardEvent) => {
|
||||
if (showDocumentPopover) {
|
||||
|
|
@ -285,15 +298,13 @@ const Composer: FC = () => {
|
|||
[showDocumentPopover]
|
||||
);
|
||||
|
||||
// Handle submit from inline editor (Enter key)
|
||||
// Submit message (blocked during streaming or when document picker is open)
|
||||
const handleSubmit = useCallback(() => {
|
||||
// Prevent sending while a response is still streaming
|
||||
if (isThreadRunning) {
|
||||
return;
|
||||
}
|
||||
if (!showDocumentPopover) {
|
||||
composerRuntime.send();
|
||||
// Clear the editor after sending
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
setMentionedDocumentIds({
|
||||
|
|
@ -309,6 +320,7 @@ const Composer: FC = () => {
|
|||
setMentionedDocumentIds,
|
||||
]);
|
||||
|
||||
// Remove document from mentions and sync IDs to atom
|
||||
const handleDocumentRemove = useCallback(
|
||||
(docId: number, docType?: string) => {
|
||||
setMentionedDocuments((prev) => {
|
||||
|
|
@ -327,6 +339,7 @@ const Composer: FC = () => {
|
|||
[setMentionedDocuments, setMentionedDocumentIds]
|
||||
);
|
||||
|
||||
// Add selected documents from picker, insert chips, and sync IDs to atom
|
||||
const handleDocumentsMention = useCallback(
|
||||
(documents: Pick<Document, "id" | "title" | "document_type">[]) => {
|
||||
const existingKeys = new Set(mentionedDocuments.map((d) => `${d.document_type}:${d.id}`));
|
||||
|
|
@ -364,7 +377,7 @@ const Composer: FC = () => {
|
|||
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col">
|
||||
<ComposerPrimitive.AttachmentDropzone className="aui-composer-attachment-dropzone flex w-full flex-col rounded-2xl border-input bg-muted px-1 pt-2 outline-none transition-shadow data-[dragging=true]:border-ring data-[dragging=true]:border-dashed data-[dragging=true]:bg-accent/50">
|
||||
<ComposerAttachments />
|
||||
{/* -------- Inline Mention Editor -------- */}
|
||||
{/* Inline editor with @mention support */}
|
||||
<div ref={editorContainerRef} className="aui-composer-input-wrapper px-3 pt-3 pb-6">
|
||||
<InlineMentionEditor
|
||||
ref={editorRef}
|
||||
|
|
@ -379,45 +392,29 @@ const Composer: FC = () => {
|
|||
/>
|
||||
</div>
|
||||
|
||||
{/* -------- Document mention popover (rendered via portal) -------- */}
|
||||
{/* Document picker popover (portal to body for proper z-index stacking) */}
|
||||
{showDocumentPopover &&
|
||||
typeof document !== "undefined" &&
|
||||
createPortal(
|
||||
<>
|
||||
{/* Backdrop */}
|
||||
<button
|
||||
type="button"
|
||||
className="fixed inset-0 cursor-default"
|
||||
style={{ zIndex: 9998 }}
|
||||
onClick={() => setShowDocumentPopover(false)}
|
||||
aria-label="Close document picker"
|
||||
/>
|
||||
{/* Popover positioned above input */}
|
||||
<div
|
||||
className="fixed shadow-2xl rounded-lg border border-border overflow-hidden bg-popover"
|
||||
style={{
|
||||
zIndex: 9999,
|
||||
bottom: editorContainerRef.current
|
||||
? `${window.innerHeight - editorContainerRef.current.getBoundingClientRect().top + 8}px`
|
||||
: "200px",
|
||||
left: editorContainerRef.current
|
||||
? `${editorContainerRef.current.getBoundingClientRect().left}px`
|
||||
: "50%",
|
||||
}}
|
||||
>
|
||||
<DocumentMentionPicker
|
||||
ref={documentPickerRef}
|
||||
searchSpaceId={Number(search_space_id)}
|
||||
onSelectionChange={handleDocumentsMention}
|
||||
onDone={() => {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}}
|
||||
initialSelectedDocuments={mentionedDocuments}
|
||||
externalSearch={mentionQuery}
|
||||
/>
|
||||
</div>
|
||||
</>,
|
||||
<DocumentMentionPicker
|
||||
ref={documentPickerRef}
|
||||
searchSpaceId={Number(search_space_id)}
|
||||
onSelectionChange={handleDocumentsMention}
|
||||
onDone={() => {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}}
|
||||
initialSelectedDocuments={mentionedDocuments}
|
||||
externalSearch={mentionQuery}
|
||||
containerStyle={{
|
||||
bottom: editorContainerRef.current
|
||||
? `${window.innerHeight - editorContainerRef.current.getBoundingClientRect().top + 8}px`
|
||||
: "200px",
|
||||
left: editorContainerRef.current
|
||||
? `${editorContainerRef.current.getBoundingClientRect().left}px`
|
||||
: "50%",
|
||||
}}
|
||||
/>,
|
||||
document.body
|
||||
)}
|
||||
<ComposerAction />
|
||||
|
|
@ -590,17 +587,6 @@ const AssistantMessageInner: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const AssistantMessage: FC = () => {
|
||||
return (
|
||||
<MessagePrimitive.Root
|
||||
className="aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||
data-role="assistant"
|
||||
>
|
||||
<AssistantMessageInner />
|
||||
</MessagePrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
||||
const AssistantActionBar: FC = () => {
|
||||
return (
|
||||
<ActionBarPrimitive.Root
|
||||
|
|
|
|||
|
|
@ -27,12 +27,7 @@ export const TooltipIconButton = forwardRef<HTMLButtonElement, TooltipIconButton
|
|||
<span className="aui-sr-only sr-only">{tooltip}</span>
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent
|
||||
side={side}
|
||||
className="bg-black text-white font-medium shadow-xl px-3 py-1.5 dark:bg-zinc-800 dark:text-zinc-50 border-none"
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipContent>
|
||||
<TooltipContent side={side}>{tooltip}</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ export const UserMessage: FC = () => {
|
|||
</div>
|
||||
{/* User avatar - only shown in shared chats */}
|
||||
{author && (
|
||||
<div className="shrink-0">
|
||||
<div className="shrink-0 mb-1.5">
|
||||
<UserAvatar displayName={author.displayName} avatarUrl={author.avatarUrl} />
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,303 @@
|
|||
"use client";
|
||||
|
||||
import { Send, X } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Popover, PopoverAnchor, PopoverContent } from "@/components/ui/popover";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { MemberMentionPicker } from "../member-mention-picker/member-mention-picker";
|
||||
import type { MemberOption } from "../member-mention-picker/types";
|
||||
import type { CommentComposerProps, InsertedMention, MentionState } from "./types";
|
||||
|
||||
function convertDisplayToData(displayContent: string, mentions: InsertedMention[]): string {
|
||||
let result = displayContent;
|
||||
|
||||
const sortedMentions = [...mentions].sort((a, b) => b.displayName.length - a.displayName.length);
|
||||
|
||||
for (const mention of sortedMentions) {
|
||||
const displayPattern = new RegExp(
|
||||
`@${escapeRegExp(mention.displayName)}(?=\\s|$|[.,!?;:])`,
|
||||
"g"
|
||||
);
|
||||
const dataFormat = `@[${mention.id}]`;
|
||||
result = result.replace(displayPattern, dataFormat);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function escapeRegExp(string: string): string {
|
||||
return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
|
||||
}
|
||||
|
||||
function findMentionTrigger(
|
||||
text: string,
|
||||
cursorPos: number,
|
||||
insertedMentions: InsertedMention[]
|
||||
): { isActive: boolean; query: string; startIndex: number } {
|
||||
const textBeforeCursor = text.slice(0, cursorPos);
|
||||
|
||||
const mentionMatch = textBeforeCursor.match(/(?:^|[\s])@([^\s]*)$/);
|
||||
|
||||
if (!mentionMatch) {
|
||||
return { isActive: false, query: "", startIndex: 0 };
|
||||
}
|
||||
|
||||
const fullMatch = mentionMatch[0];
|
||||
const query = mentionMatch[1];
|
||||
const atIndex = cursorPos - query.length - 1;
|
||||
|
||||
if (atIndex > 0) {
|
||||
const charBefore = text[atIndex - 1];
|
||||
if (charBefore && !/[\s]/.test(charBefore)) {
|
||||
return { isActive: false, query: "", startIndex: 0 };
|
||||
}
|
||||
}
|
||||
|
||||
const textFromAt = text.slice(atIndex);
|
||||
|
||||
for (const mention of insertedMentions) {
|
||||
const mentionPattern = `@${mention.displayName}`;
|
||||
|
||||
if (textFromAt.startsWith(mentionPattern)) {
|
||||
const charAfterMention = text[atIndex + mentionPattern.length];
|
||||
if (!charAfterMention || /[\s.,!?;:]/.test(charAfterMention)) {
|
||||
if (cursorPos <= atIndex + mentionPattern.length) {
|
||||
return { isActive: false, query: "", startIndex: 0 };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (query.length > 50) {
|
||||
return { isActive: false, query: "", startIndex: 0 };
|
||||
}
|
||||
|
||||
return { isActive: true, query, startIndex: atIndex };
|
||||
}
|
||||
|
||||
export function CommentComposer({
|
||||
members,
|
||||
membersLoading = false,
|
||||
placeholder = "Write a comment...",
|
||||
submitLabel = "Send",
|
||||
isSubmitting = false,
|
||||
onSubmit,
|
||||
onCancel,
|
||||
autoFocus = false,
|
||||
initialValue = "",
|
||||
}: CommentComposerProps) {
|
||||
const [displayContent, setDisplayContent] = useState(initialValue);
|
||||
const [insertedMentions, setInsertedMentions] = useState<InsertedMention[]>([]);
|
||||
const [mentionsInitialized, setMentionsInitialized] = useState(false);
|
||||
const [mentionState, setMentionState] = useState<MentionState>({
|
||||
isActive: false,
|
||||
query: "",
|
||||
startIndex: 0,
|
||||
});
|
||||
const [highlightedIndex, setHighlightedIndex] = useState(0);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const filteredMembers = mentionState.query
|
||||
? members.filter(
|
||||
(member) =>
|
||||
member.displayName?.toLowerCase().includes(mentionState.query.toLowerCase()) ||
|
||||
member.email.toLowerCase().includes(mentionState.query.toLowerCase())
|
||||
)
|
||||
: members;
|
||||
|
||||
const closeMentionPicker = useCallback(() => {
|
||||
setMentionState({ isActive: false, query: "", startIndex: 0 });
|
||||
setHighlightedIndex(0);
|
||||
}, []);
|
||||
|
||||
const insertMention = useCallback(
|
||||
(member: MemberOption) => {
|
||||
const displayName = member.displayName || member.email.split("@")[0];
|
||||
const before = displayContent.slice(0, mentionState.startIndex);
|
||||
const cursorPos = textareaRef.current?.selectionStart ?? displayContent.length;
|
||||
const after = displayContent.slice(cursorPos);
|
||||
const mentionText = `@${displayName} `;
|
||||
const newContent = before + mentionText + after;
|
||||
|
||||
setDisplayContent(newContent);
|
||||
setInsertedMentions((prev) => {
|
||||
const exists = prev.some((m) => m.id === member.id && m.displayName === displayName);
|
||||
if (exists) return prev;
|
||||
return [...prev, { id: member.id, displayName }];
|
||||
});
|
||||
closeMentionPicker();
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
if (textareaRef.current) {
|
||||
const cursorPos = before.length + mentionText.length;
|
||||
textareaRef.current.focus();
|
||||
textareaRef.current.setSelectionRange(cursorPos, cursorPos);
|
||||
}
|
||||
});
|
||||
},
|
||||
[displayContent, mentionState.startIndex, closeMentionPicker]
|
||||
);
|
||||
|
||||
const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
const value = e.target.value;
|
||||
const cursorPos = e.target.selectionStart;
|
||||
setDisplayContent(value);
|
||||
|
||||
const triggerResult = findMentionTrigger(value, cursorPos, insertedMentions);
|
||||
|
||||
if (triggerResult.isActive) {
|
||||
setMentionState(triggerResult);
|
||||
setHighlightedIndex(0);
|
||||
} else if (mentionState.isActive) {
|
||||
closeMentionPicker();
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (!mentionState.isActive) {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleSubmit();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
switch (e.key) {
|
||||
case "ArrowDown":
|
||||
case "Tab":
|
||||
if (!e.shiftKey) {
|
||||
e.preventDefault();
|
||||
setHighlightedIndex((prev) => (prev < filteredMembers.length - 1 ? prev + 1 : 0));
|
||||
} else if (e.key === "Tab") {
|
||||
e.preventDefault();
|
||||
setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filteredMembers.length - 1));
|
||||
}
|
||||
break;
|
||||
case "ArrowUp":
|
||||
e.preventDefault();
|
||||
setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filteredMembers.length - 1));
|
||||
break;
|
||||
case "Enter":
|
||||
e.preventDefault();
|
||||
if (filteredMembers[highlightedIndex]) {
|
||||
insertMention(filteredMembers[highlightedIndex]);
|
||||
}
|
||||
break;
|
||||
case "Escape":
|
||||
e.preventDefault();
|
||||
closeMentionPicker();
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = () => {
|
||||
const trimmed = displayContent.trim();
|
||||
if (!trimmed || isSubmitting) return;
|
||||
|
||||
const dataContent = convertDisplayToData(trimmed, insertedMentions);
|
||||
onSubmit(dataContent);
|
||||
setDisplayContent("");
|
||||
setInsertedMentions([]);
|
||||
};
|
||||
|
||||
// Pre-populate insertedMentions from initialValue when members are loaded
|
||||
useEffect(() => {
|
||||
if (mentionsInitialized || !initialValue || members.length === 0) return;
|
||||
|
||||
const mentionPattern = /@([^\s@]+(?:\s+[^\s@]+)*?)(?=\s|$|[.,!?;:]|@)/g;
|
||||
const foundMentions: InsertedMention[] = [];
|
||||
let match: RegExpExecArray | null;
|
||||
|
||||
while ((match = mentionPattern.exec(initialValue)) !== null) {
|
||||
const displayName = match[1];
|
||||
const member = members.find(
|
||||
(m) => m.displayName === displayName || m.email.split("@")[0] === displayName
|
||||
);
|
||||
if (member) {
|
||||
const exists = foundMentions.some((m) => m.id === member.id);
|
||||
if (!exists) {
|
||||
foundMentions.push({ id: member.id, displayName });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (foundMentions.length > 0) {
|
||||
setInsertedMentions(foundMentions);
|
||||
}
|
||||
setMentionsInitialized(true);
|
||||
}, [initialValue, members, mentionsInitialized]);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoFocus && textareaRef.current) {
|
||||
textareaRef.current.focus();
|
||||
}
|
||||
}, [autoFocus]);
|
||||
|
||||
const canSubmit = displayContent.trim().length > 0 && !isSubmitting;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Popover
|
||||
open={mentionState.isActive}
|
||||
onOpenChange={(open) => !open && closeMentionPicker()}
|
||||
modal={false}
|
||||
>
|
||||
<PopoverAnchor asChild>
|
||||
<Textarea
|
||||
ref={textareaRef}
|
||||
value={displayContent}
|
||||
onChange={handleInputChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
className="min-h-[80px] resize-none"
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
</PopoverAnchor>
|
||||
<PopoverContent
|
||||
side="top"
|
||||
align="start"
|
||||
sideOffset={4}
|
||||
collisionPadding={8}
|
||||
className="w-72 p-0"
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<MemberMentionPicker
|
||||
members={members}
|
||||
query={mentionState.query}
|
||||
highlightedIndex={highlightedIndex}
|
||||
isLoading={membersLoading}
|
||||
onSelect={insertMention}
|
||||
onHighlightChange={setHighlightedIndex}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{onCancel && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={onCancel}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
<X className="mr-1 size-4" />
|
||||
Cancel
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
type="button"
|
||||
size="sm"
|
||||
onClick={handleSubmit}
|
||||
disabled={!canSubmit}
|
||||
className={cn(!canSubmit && "opacity-50")}
|
||||
>
|
||||
<Send className="mr-1 size-4" />
|
||||
{submitLabel}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue