mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
Merge pull request #1067 from MODSetter/dev
feat: OneDrive & Dropbox connectors, desktop quick-ask, prompt library, and UX improvements
commit
deccbca506
248 changed files with 17173 additions and 3487 deletions
|
|
@ -23,7 +23,7 @@
|
|||
# SurfSense
|
||||
Conecta cualquier LLM a tus fuentes de conocimiento internas y chatea con él en tiempo real junto a tu equipo. Alternativa de código abierto a NotebookLM, Perplexity y Glean.
|
||||
|
||||
SurfSense es un agente de investigación de IA altamente personalizable, conectado a fuentes externas como motores de búsqueda (SearxNG, Tavily, LinkUp), Google Drive, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian y más por venir.
|
||||
SurfSense es un agente de investigación de IA altamente personalizable, conectado a fuentes externas como motores de búsqueda (SearxNG, Tavily, LinkUp), Google Drive, OneDrive, Dropbox, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian y más por venir.
|
||||
|
||||
|
||||
|
||||
|
|
@ -149,14 +149,14 @@ Para Docker Compose, instalación manual y otras opciones de despliegue, consult
|
|||
| Generación de Presentaciones | Crea presentaciones editables basadas en diapositivas |
|
||||
| Generación de Podcasts | Podcast de 3 min en menos de 20 segundos; múltiples proveedores TTS (OpenAI, Azure, Kokoro) |
|
||||
| Extensión de Navegador | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
|
||||
| 25+ Conectores | Motores de búsqueda, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord y [más](#fuentes-externas) |
|
||||
| 27+ Conectores | Motores de búsqueda, Google Drive, OneDrive, Dropbox, Slack, Teams, Jira, Notion, GitHub, Discord y [más](#fuentes-externas) |
|
||||
| Auto-Hospedable | Código abierto, Docker en un solo comando o Docker Compose completo para producción |
|
||||
|
||||
<details>
|
||||
<summary><b>Lista completa de Fuentes Externas</b></summary>
|
||||
<a id="fuentes-externas"></a>
|
||||
|
||||
Motores de Búsqueda (Tavily, LinkUp) · SearxNG · Google Drive · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Videos de YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, y más por venir.
|
||||
Motores de Búsqueda (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Videos de YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, y más por venir.
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
# SurfSense
|
||||
किसी भी LLM को अपने आंतरिक ज्ञान स्रोतों से जोड़ें और अपनी टीम के साथ रीयल-टाइम में चैट करें। NotebookLM, Perplexity और Glean का ओपन सोर्स विकल्प।
|
||||
|
||||
SurfSense एक अत्यधिक अनुकूलन योग्य AI शोध एजेंट है, जो बाहरी स्रोतों से जुड़ा है जैसे सर्च इंजन (SearxNG, Tavily, LinkUp), Google Drive, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian और भी बहुत कुछ आने वाला है।
|
||||
SurfSense एक अत्यधिक अनुकूलन योग्य AI शोध एजेंट है, जो बाहरी स्रोतों से जुड़ा है जैसे सर्च इंजन (SearxNG, Tavily, LinkUp), Google Drive, OneDrive, Dropbox, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian और भी बहुत कुछ आने वाला है।
|
||||
|
||||
|
||||
|
||||
|
|
@ -149,14 +149,14 @@ Docker Compose, मैनुअल इंस्टॉलेशन और अन
|
|||
| प्रेजेंटेशन जनरेशन | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं |
|
||||
| पॉडकास्ट जनरेशन | 20 सेकंड से कम में 3 मिनट का पॉडकास्ट; कई TTS प्रदाता (OpenAI, Azure, Kokoro) |
|
||||
| ब्राउज़र एक्सटेंशन | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
|
||||
| 25+ कनेक्टर्स | सर्च इंजन, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord और [अधिक](#बाहरी-स्रोत) |
|
||||
| 27+ कनेक्टर्स | सर्च इंजन, Google Drive, OneDrive, Dropbox, Slack, Teams, Jira, Notion, GitHub, Discord और [अधिक](#बाहरी-स्रोत) |
|
||||
| सेल्फ-होस्ट करने योग्य | ओपन सोर्स, Docker एक कमांड या प्रोडक्शन के लिए पूर्ण Docker Compose |
|
||||
|
||||
<details>
|
||||
<summary><b>बाहरी स्रोतों की पूरी सूची</b></summary>
|
||||
<a id="बाहरी-स्रोत"></a>
|
||||
|
||||
सर्च इंजन (Tavily, LinkUp) · SearxNG · Google Drive · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube वीडियो · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, और भी बहुत कुछ आने वाला है।
|
||||
सर्च इंजन (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube वीडियो · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, और भी बहुत कुछ आने वाला है।
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
# SurfSense
|
||||
Connect any LLM to your internal knowledge sources and chat with it in real time alongside your team. OSS alternative to NotebookLM, Perplexity, and Glean.
|
||||
|
||||
SurfSense is a highly customizable AI research agent, connected to external sources such as Search Engines (SearxNG, Tavily, LinkUp), Google Drive, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian and more to come.
|
||||
SurfSense is a highly customizable AI research agent, connected to external sources such as Search Engines (SearxNG, Tavily, LinkUp), Google Drive, OneDrive, Dropbox, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian and more to come.
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,14 +150,14 @@ For Docker Compose, manual installation, and other deployment options, see the [
|
|||
| Presentation Generation | Create editable, slide based presentations |
|
||||
| Podcast Generation | 3 min podcast in under 20 seconds; multiple TTS providers (OpenAI, Azure, Kokoro) |
|
||||
| Browser Extension | Cross browser extension to save any webpage, including auth protected pages |
|
||||
| 25+ Connectors | Search Engines, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord & [more](#external-sources) |
|
||||
| 27+ Connectors | Search Engines, Google Drive, OneDrive, Dropbox, Slack, Teams, Jira, Notion, GitHub, Discord & [more](#external-sources) |
|
||||
| Self Hostable | Open source, Docker one liner or full Docker Compose for production |
|
||||
|
||||
<details>
|
||||
<summary><b>Full list of External Sources</b></summary>
|
||||
<a id="external-sources"></a>
|
||||
|
||||
Search Engines (Tavily, LinkUp) · SearxNG · Google Drive · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube Videos · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, and more to come.
|
||||
Search Engines (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube Videos · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, and more to come.
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
# SurfSense
|
||||
Conecte qualquer LLM às suas fontes de conhecimento internas e converse com ele em tempo real junto com sua equipe. Alternativa de código aberto ao NotebookLM, Perplexity e Glean.
|
||||
|
||||
SurfSense é um agente de pesquisa de IA altamente personalizável, conectado a fontes externas como mecanismos de busca (SearxNG, Tavily, LinkUp), Google Drive, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian e mais por vir.
|
||||
SurfSense é um agente de pesquisa de IA altamente personalizável, conectado a fontes externas como mecanismos de busca (SearxNG, Tavily, LinkUp), Google Drive, OneDrive, Dropbox, Slack, Microsoft Teams, Linear, Jira, ClickUp, Confluence, BookStack, Gmail, Notion, YouTube, GitHub, Discord, Airtable, Google Calendar, Luma, Circleback, Elasticsearch, Obsidian e mais por vir.
|
||||
|
||||
|
||||
|
||||
|
|
@ -149,14 +149,14 @@ Para Docker Compose, instalação manual e outras opções de implantação, con
|
|||
| Geração de Apresentações | Cria apresentações editáveis baseadas em slides |
|
||||
| Geração de Podcasts | Podcast de 3 min em menos de 20 segundos; múltiplos provedores TTS (OpenAI, Azure, Kokoro) |
|
||||
| Extensão de Navegador | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
|
||||
| 25+ Conectores | Mecanismos de busca, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord e [mais](#fontes-externas) |
|
||||
| 27+ Conectores | Mecanismos de busca, Google Drive, OneDrive, Dropbox, Slack, Teams, Jira, Notion, GitHub, Discord e [mais](#fontes-externas) |
|
||||
| Auto-Hospedável | Código aberto, Docker em um único comando ou Docker Compose completo para produção |
|
||||
|
||||
<details>
|
||||
<summary><b>Lista completa de Fontes Externas</b></summary>
|
||||
<a id="fontes-externas"></a>
|
||||
|
||||
Mecanismos de Busca (Tavily, LinkUp) · SearxNG · Google Drive · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Vídeos do YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, e mais por vir.
|
||||
Mecanismos de Busca (Tavily, LinkUp) · SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · Vídeos do YouTube · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian, e mais por vir.
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
# SurfSense
|
||||
将任何 LLM 连接到您的内部知识源,并与团队成员实时聊天。NotebookLM、Perplexity 和 Glean 的开源替代方案。
|
||||
|
||||
SurfSense 是一个高度可定制的 AI 研究助手,可以连接外部数据源,如搜索引擎(SearxNG、Tavily、LinkUp)、Google Drive、Slack、Microsoft Teams、Linear、Jira、ClickUp、Confluence、BookStack、Gmail、Notion、YouTube、GitHub、Discord、Airtable、Google Calendar、Luma、Circleback、Elasticsearch、Obsidian 等,未来还会支持更多。
|
||||
SurfSense 是一个高度可定制的 AI 研究助手,可以连接外部数据源,如搜索引擎(SearxNG、Tavily、LinkUp)、Google Drive、OneDrive、Dropbox、Slack、Microsoft Teams、Linear、Jira、ClickUp、Confluence、BookStack、Gmail、Notion、YouTube、GitHub、Discord、Airtable、Google Calendar、Luma、Circleback、Elasticsearch、Obsidian 等,未来还会支持更多。
|
||||
|
||||
|
||||
|
||||
|
|
@ -149,14 +149,14 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
| 演示文稿生成 | 创建可编辑的幻灯片式演示文稿 |
|
||||
| 播客生成 | 20 秒内生成 3 分钟播客;多种 TTS 提供商(OpenAI、Azure、Kokoro) |
|
||||
| 浏览器扩展 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
|
||||
| 25+ 连接器 | 搜索引擎、Google Drive、Slack、Teams、Jira、Notion、GitHub、Discord 等[更多](#外部数据源) |
|
||||
| 27+ 连接器 | 搜索引擎、Google Drive、OneDrive、Dropbox、Slack、Teams、Jira、Notion、GitHub、Discord 等[更多](#外部数据源) |
|
||||
| 可自托管 | 开源,Docker 一行命令或完整 Docker Compose 用于生产环境 |
|
||||
|
||||
<details>
|
||||
<summary><b>外部数据源完整列表</b></summary>
|
||||
<a id="外部数据源"></a>
|
||||
|
||||
搜索引擎(Tavily、LinkUp)· SearxNG · Google Drive · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube 视频 · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian,更多即将推出。
|
||||
搜索引擎(Tavily、LinkUp)· SearxNG · Google Drive · OneDrive · Dropbox · Slack · Microsoft Teams · Linear · Jira · ClickUp · Confluence · BookStack · Notion · Gmail · YouTube 视频 · GitHub · Discord · Airtable · Google Calendar · Luma · Circleback · Elasticsearch · Obsidian,更多即将推出。
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
|||
|
|
@ -127,6 +127,20 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
# Supports TLS: rediss://:password@host:6380/0
|
||||
# REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stripe (pay-as-you-go page packs — disabled by default)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Set TRUE to allow users to buy additional page packs via Stripe Checkout
|
||||
STRIPE_PAGE_BUYING_ENABLED=FALSE
|
||||
# STRIPE_SECRET_KEY=sk_test_...
|
||||
# STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
# STRIPE_PRICE_ID=price_...
|
||||
# STRIPE_PAGES_PER_UNIT=1000
|
||||
# STRIPE_RECONCILIATION_INTERVAL=10m
|
||||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
|
@ -203,10 +217,16 @@ STT_SERVICE=local/base
|
|||
# AIRTABLE_CLIENT_SECRET=
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
||||
|
||||
# -- Microsoft Teams --
|
||||
# TEAMS_CLIENT_ID=
|
||||
# TEAMS_CLIENT_SECRET=
|
||||
# -- Microsoft OAuth (Teams & OneDrive) --
|
||||
# MICROSOFT_CLIENT_ID=
|
||||
# MICROSOFT_CLIENT_SECRET=
|
||||
# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
# ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback
|
||||
|
||||
# -- Dropbox --
|
||||
# DROPBOX_APP_KEY=
|
||||
# DROPBOX_APP_SECRET=
|
||||
# DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback
|
||||
|
||||
# -- Composio --
|
||||
# COMPOSIO_API_KEY=
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ SurfSense 现已支持以下国产 LLM:
|
|||
|
||||
1. 登录 SurfSense Dashboard
|
||||
2. 进入 **Settings** → **API Keys** (或 **LLM Configurations**)
|
||||
3. 点击 **Add New Configuration**
|
||||
3. 点击 **Add LLM Model**
|
||||
4. 从 **Provider** 下拉菜单中选择你的国产 LLM 提供商
|
||||
5. 填写必填字段(见下方各提供商详细配置)
|
||||
6. 点击 **Save**
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ REDIS_APP_URL=redis://localhost:6379/0
|
|||
# # Run every 2 hours
|
||||
# SCHEDULE_CHECKER_INTERVAL=2h
|
||||
SCHEDULE_CHECKER_INTERVAL=5m
|
||||
# How often the Stripe reconciliation beat task runs
|
||||
STRIPE_RECONCILIATION_INTERVAL=10m
|
||||
|
||||
SECRET_KEY=SECRET
|
||||
|
||||
|
|
@ -42,6 +44,20 @@ SECRET_KEY=SECRET
|
|||
|
||||
NEXT_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# Stripe Checkout for pay-as-you-go page packs
|
||||
# Configure STRIPE_PRICE_ID to point at your 1,000-page price in Stripe.
|
||||
# Pages granted per purchase = quantity * STRIPE_PAGES_PER_UNIT.
|
||||
STRIPE_SECRET_KEY=sk_test_...
|
||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
STRIPE_PRICE_ID=price_...
|
||||
STRIPE_PAGES_PER_UNIT=1000
|
||||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
# Max pending purchases to check per reconciliation run
|
||||
STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||
|
||||
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)
|
||||
# BACKEND_URL=https://api.yourdomain.com
|
||||
|
||||
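The comment in this hunk states that pages granted per purchase equal `quantity * STRIPE_PAGES_PER_UNIT`. A minimal sketch of that arithmetic follows. It assumes a hypothetical helper `pages_granted` (not part of this commit) and a fallback of 1000 taken from the example value above; the real backend may read and validate the setting differently.

```python
import os
from collections.abc import Mapping


def pages_granted(quantity: int, env: Mapping[str, str] = os.environ) -> int:
    """Return pages credited for a purchase of `quantity` page packs.

    Hypothetical helper: only the formula comes from the .env comment.
    """
    # Fallback of 1000 mirrors the example value, it is an assumption here.
    per_unit = int(env.get("STRIPE_PAGES_PER_UNIT", "1000"))
    if quantity < 1 or per_unit < 1:
        raise ValueError("quantity and STRIPE_PAGES_PER_UNIT must be positive")
    return quantity * per_unit
```

With `STRIPE_PAGES_PER_UNIT=1000`, buying three units would credit 3000 pages.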
|
|
@ -74,7 +90,7 @@ DISCORD_CLIENT_SECRET=your_discord_client_secret_here
|
|||
DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback
|
||||
DISCORD_BOT_TOKEN=your_bot_token_from_developer_portal
|
||||
|
||||
# Atlassian OAuth Configuration
|
||||
# Atlassian OAuth Configuration (Jira & Confluence)
|
||||
ATLASSIAN_CLIENT_ID=your_atlassian_client_id_here
|
||||
ATLASSIAN_CLIENT_SECRET=your_atlassian_client_secret_here
|
||||
JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback
|
||||
|
|
@ -95,10 +111,16 @@ SLACK_CLIENT_ID=your_slack_client_id_here
|
|||
SLACK_CLIENT_SECRET=your_slack_client_secret_here
|
||||
SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback
|
||||
|
||||
# Teams OAuth Configuration
|
||||
TEAMS_CLIENT_ID=your_teams_client_id_here
|
||||
TEAMS_CLIENT_SECRET=your_teams_client_secret_here
|
||||
# Microsoft OAuth (Teams & OneDrive)
|
||||
MICROSOFT_CLIENT_ID=your_microsoft_client_id_here
|
||||
MICROSOFT_CLIENT_SECRET=your_microsoft_client_secret_here
|
||||
TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback
|
||||
|
||||
# Dropbox Connector
|
||||
DROPBOX_APP_KEY=your_dropbox_app_key_here
|
||||
DROPBOX_APP_SECRET=your_dropbox_app_secret_here
|
||||
DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback
|
||||
|
||||
# Composio Connector
|
||||
# NOTE: Disable "Mask Connected Account Secrets" in Composio dashboard (Settings → Project Settings) for Google indexing to work.
|
||||
|
|
@ -143,6 +165,14 @@ STT_SERVICE=local/base
|
|||
# STT_SERVICE_API_KEY=""
|
||||
# STT_SERVICE_API_BASE=
|
||||
|
||||
# Video presentation defaults
|
||||
# Maximum number of generated slides for a single video presentation.
|
||||
VIDEO_PRESENTATION_MAX_SLIDES=30
|
||||
# Frames per second used for slide timing calculations.
|
||||
VIDEO_PRESENTATION_FPS=30
|
||||
# Minimum duration per slide when audio is missing or very short.
|
||||
VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
||||
|
||||
|
||||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
"""Add OneDrive connector enums

Revision ID: 110
Revises: 109
Create Date: 2026-03-28 00:00:00.000000

"""

from collections.abc import Sequence

from alembic import op

revision: str = "110"
down_revision: str | None = "109"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'searchsourceconnectortype' AND e.enumlabel = 'ONEDRIVE_CONNECTOR'
            ) THEN
                ALTER TYPE searchsourceconnectortype ADD VALUE 'ONEDRIVE_CONNECTOR';
            END IF;
        END
        $$;
        """
    )

    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'documenttype' AND e.enumlabel = 'ONEDRIVE_FILE'
            ) THEN
                ALTER TYPE documenttype ADD VALUE 'ONEDRIVE_FILE';
            END IF;
        END
        $$;
        """
    )


def downgrade() -> None:
    pass
|
||||
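The migration above adds each enum label inside a `DO $$ ... $$` block guarded by a `pg_enum` lookup, so re-running it is harmless. The same pattern can be sketched as a small SQL builder. `add_enum_value_sql` is a hypothetical helper for illustration, not something this commit defines, and the identifier check is an added assumption because the names are interpolated into the SQL text.

```python
def add_enum_value_sql(enum_type: str, label: str) -> str:
    """Build an idempotent DO block that adds `label` to `enum_type`."""
    # Names are interpolated directly, so accept only simple identifiers.
    for name in (enum_type, label):
        if not name.replace("_", "").isalnum():
            raise ValueError(f"unsafe identifier: {name!r}")
    return f"""
    DO $$
    BEGIN
        IF NOT EXISTS (
            SELECT 1 FROM pg_type t
            JOIN pg_enum e ON t.oid = e.enumtypid
            WHERE t.typname = '{enum_type}' AND e.enumlabel = '{label}'
        ) THEN
            ALTER TYPE {enum_type} ADD VALUE '{label}';
        END IF;
    END
    $$;
    """
```

In a migration the result would be passed to `op.execute(...)`. PostgreSQL also accepts `ALTER TYPE ... ADD VALUE IF NOT EXISTS`, which may be a shorter alternative where it fits. PostgreSQL has no statement for dropping an enum value, which is presumably why `downgrade()` is a no-op.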
52
surfsense_backend/alembic/versions/111_add_prompts_table.py
Normal file
|
|
@ -0,0 +1,52 @@
|
"""add prompts table

Revision ID: 111
Revises: 110
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "111"
down_revision: str | None = "110"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    conn = op.get_bind()

    result = conn.execute(
        sa.text("SELECT 1 FROM pg_type WHERE typname = 'prompt_mode'")
    )
    if not result.fetchone():
        op.execute("CREATE TYPE prompt_mode AS ENUM ('transform', 'explore')")

    result = conn.execute(
        sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = 'prompts'")
    )
    if not result.fetchone():
        op.execute("""
            CREATE TABLE prompts (
                id SERIAL PRIMARY KEY,
                user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
                search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
                name VARCHAR(200) NOT NULL,
                prompt TEXT NOT NULL,
                mode prompt_mode NOT NULL,
                icon VARCHAR(50),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
            )
        """)
        op.execute("CREATE INDEX ix_prompts_user_id ON prompts (user_id)")
        op.execute(
            "CREATE INDEX ix_prompts_search_space_id ON prompts (search_space_id)"
        )


def downgrade() -> None:
    op.execute("DROP TABLE IF EXISTS prompts")
    op.execute("DROP TYPE IF EXISTS prompt_mode")
|
||||
|
|
@ -0,0 +1,54 @@
|
"""Add Dropbox connector enums

Revision ID: 112
Revises: 111
Create Date: 2026-03-30 00:00:00.000000

"""

from collections.abc import Sequence

from alembic import op

revision: str = "112"
down_revision: str | None = "111"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'searchsourceconnectortype' AND e.enumlabel = 'DROPBOX_CONNECTOR'
            ) THEN
                ALTER TYPE searchsourceconnectortype ADD VALUE 'DROPBOX_CONNECTOR';
            END IF;
        END
        $$;
        """
    )

    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'documenttype' AND e.enumlabel = 'DROPBOX_FILE'
            ) THEN
                ALTER TYPE documenttype ADD VALUE 'DROPBOX_FILE';
            END IF;
        END
        $$;
        """
    )


def downgrade() -> None:
    pass
|
||||
|
|
@ -0,0 +1,62 @@
|
"""add prompt library schema: is_public, default_prompt_slug, version, drop icon

Revision ID: 113
Revises: 112
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "113"
down_revision: str | None = "112"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    op.execute(
        "ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
        " is_public BOOLEAN NOT NULL DEFAULT false"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_prompts_is_public"
        " ON prompts (is_public) WHERE is_public = true"
    )
    op.execute(
        "ALTER TABLE prompts ADD COLUMN IF NOT EXISTS default_prompt_slug VARCHAR(100)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_prompts_default_prompt_slug"
        " ON prompts (default_prompt_slug)"
    )
    conn = op.get_bind()
    exists = conn.execute(
        sa.text(
            "SELECT 1 FROM pg_constraint WHERE conname = 'uq_prompt_user_default_slug'"
        )
    ).scalar()
    if not exists:
        op.execute(
            "ALTER TABLE prompts ADD CONSTRAINT uq_prompt_user_default_slug"
            " UNIQUE (user_id, default_prompt_slug)"
        )
    op.execute(
        "ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
        " version INTEGER NOT NULL DEFAULT 1"
    )
    op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon")


def downgrade() -> None:
    op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)")
    op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS version")
    op.execute(
        "ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_default_slug"
    )
    op.execute("DROP INDEX IF EXISTS ix_prompts_default_prompt_slug")
    op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug")
    op.execute("DROP INDEX IF EXISTS ix_prompts_is_public")
    op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS is_public")
|
||||
|
|
@ -0,0 +1,45 @@
|
"""seed default prompts for all existing users

Revision ID: 114
Revises: 113
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "114"
down_revision: str | None = "113"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            INSERT INTO prompts
                (user_id, default_prompt_slug, name, prompt, mode, version, is_public, created_at)
            SELECT u.id, d.slug, d.name, d.prompt, d.mode::prompt_mode, 1, false, now()
            FROM "user" u
            CROSS JOIN (VALUES
                ('fix-grammar', 'Fix grammar', 'Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}', 'transform'),
                ('make-shorter', 'Make shorter', 'Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}', 'transform'),
                ('translate', 'Translate', 'Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}', 'transform'),
                ('rewrite', 'Rewrite', 'Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}', 'transform'),
                ('summarize', 'Summarize', 'Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}', 'transform'),
                ('explain', 'Explain', 'Explain the following text in simple terms:\n\n{selection}', 'explore'),
                ('ask-knowledge-base','Ask my knowledge base', 'Search my knowledge base for information related to:\n\n{selection}', 'explore'),
                ('look-up-web', 'Look up on the web', 'Search the web for information about:\n\n{selection}', 'explore')
            ) AS d(slug, name, prompt, mode)
            ON CONFLICT (user_id, default_prompt_slug) DO NOTHING
            """
        )
    )


def downgrade() -> None:
    op.execute("DELETE FROM prompts WHERE default_prompt_slug IS NOT NULL")
|
||||
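Each seeded prompt ends with a `{selection}` placeholder. How the application fills it in is not shown in this part of the diff. The sketch below is one plausible approach, with `render_prompt` as a hypothetical name, using plain substitution so that any other braces in a user-authored template are left untouched.

```python
def render_prompt(template: str, selection: str) -> str:
    """Substitute the user's selected text into a prompt template.

    Hypothetical helper: str.replace is used instead of str.format so that
    other literal braces in the template do not raise KeyError or ValueError.
    """
    return template.replace("{selection}", selection)
```

For the seeded `explain` prompt, `render_prompt(template, "TLS handshake")` would yield the instruction text followed by a blank line and `TLS handshake`.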
|
|
@ -0,0 +1,120 @@
|
"""add page purchases table for Stripe-backed page packs

Revision ID: 115
Revises: 114
"""

from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

revision: str = "115"
down_revision: str | None = "114"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create page_purchases table and supporting enum/indexes."""
    conn = op.get_bind()

    enum_exists = conn.execute(
        sa.text("SELECT 1 FROM pg_type WHERE typname = 'pagepurchasestatus'")
    ).fetchone()
    if not enum_exists:
        page_purchase_status_enum = postgresql.ENUM(
            "PENDING",
            "COMPLETED",
            "FAILED",
            name="pagepurchasestatus",
            create_type=False,
        )
        page_purchase_status_enum.create(conn, checkfirst=True)

    table_exists = conn.execute(
        sa.text(
            "SELECT 1 FROM information_schema.tables WHERE table_name = 'page_purchases'"
        )
    ).fetchone()
    if not table_exists:
        op.create_table(
            "page_purchases",
            sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
            sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
            sa.Column(
                "stripe_checkout_session_id",
                sa.String(length=255),
                nullable=False,
            ),
            sa.Column(
                "stripe_payment_intent_id",
                sa.String(length=255),
                nullable=True,
            ),
            sa.Column("quantity", sa.Integer(), nullable=False),
            sa.Column("pages_granted", sa.Integer(), nullable=False),
            sa.Column("amount_total", sa.Integer(), nullable=True),
            sa.Column("currency", sa.String(length=10), nullable=True),
            sa.Column(
                "status",
                postgresql.ENUM(
                    "PENDING",
                    "COMPLETED",
                    "FAILED",
                    name="pagepurchasestatus",
                    create_type=False,
                ),
                nullable=False,
                server_default=sa.text("'PENDING'::pagepurchasestatus"),
            ),
            sa.Column("completed_at", sa.TIMESTAMP(timezone=True), nullable=True),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.ForeignKeyConstraint(
                ["user_id"],
                ["user.id"],
                ondelete="CASCADE",
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "stripe_checkout_session_id",
                name="uq_page_purchases_stripe_checkout_session_id",
            ),
        )

    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_page_purchases_user_id ON page_purchases (user_id)"
    )
    op.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS ix_page_purchases_stripe_checkout_session_id "
        "ON page_purchases (stripe_checkout_session_id)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_page_purchases_stripe_payment_intent_id "
        "ON page_purchases (stripe_payment_intent_id)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_page_purchases_status ON page_purchases (status)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS ix_page_purchases_created_at "
        "ON page_purchases (created_at)"
    )


def downgrade() -> None:
    """Drop page_purchases table and enum."""
    op.execute("DROP INDEX IF EXISTS ix_page_purchases_created_at")
    op.execute("DROP INDEX IF EXISTS ix_page_purchases_status")
    op.execute("DROP INDEX IF EXISTS ix_page_purchases_stripe_payment_intent_id")
    op.execute("DROP INDEX IF EXISTS ix_page_purchases_stripe_checkout_session_id")
    op.execute("DROP INDEX IF EXISTS ix_page_purchases_user_id")
    op.execute("DROP TABLE IF EXISTS page_purchases")
    postgresql.ENUM(name="pagepurchasestatus").drop(op.get_bind(), checkfirst=True)
|
||||
|
|
@ -84,6 +84,8 @@ _CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
|||
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
|
||||
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
|
||||
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
|
||||
"DROPBOX_CONNECTOR": "DROPBOX_FILE", # Connector type differs from document type
|
||||
"ONEDRIVE_CONNECTOR": "ONEDRIVE_FILE", # Connector type differs from document type
|
||||
# Composio connectors (unified to native document types).
|
||||
# Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
|
||||
|
|
@@ -316,6 +318,18 @@ async def create_surfsense_deep_agent(
|
|||
]
|
||||
modified_disabled_tools.extend(google_drive_tools)
|
||||
|
||||
    has_dropbox_connector = (
        available_connectors is not None and "DROPBOX_FILE" in available_connectors
    )
    if not has_dropbox_connector:
        modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])

    has_onedrive_connector = (
        available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
    )
    if not has_onedrive_connector:
        modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
|
||||
|
||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
||||
has_google_calendar_connector = (
|
||||
available_connectors is not None
|
||||
|
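Review note: the hunk above repeats one pattern per connector (check for the searchable document type, otherwise disable that connector's action tools). A minimal table-driven sketch of the same gating follows. `ACTION_TOOLS_BY_DOC_TYPE` and `tools_to_disable` are hypothetical names used only for illustration and are not part of the commit; the tool names and document types are taken from the hunk.

```python
from typing import Optional

# Hypothetical table (not in the commit): searchable document type -> action
# tools that only make sense when a connector of that type is configured.
ACTION_TOOLS_BY_DOC_TYPE = {
    "DROPBOX_FILE": ["create_dropbox_file", "delete_dropbox_file"],
    "ONEDRIVE_FILE": ["create_onedrive_file", "delete_onedrive_file"],
}


def tools_to_disable(available_connectors: Optional[list]) -> list:
    """Return the action tools whose backing connector is not available."""
    available = set(available_connectors or [])
    disabled = []
    for doc_type, tools in ACTION_TOOLS_BY_DOC_TYPE.items():
        if doc_type not in available:
            disabled.extend(tools)
    return disabled
```

With `available_connectors=None` every listed tool is disabled, which matches the `available_connectors is not None and ...` checks in the hunk.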
|
@@ -433,6 +447,7 @@ async def create_surfsense_deep_agent(
|
|||
deepagent_middleware = [
|
||||
TodoListMiddleware(),
|
||||
KnowledgeBaseSearchMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
|
|
|
|||
|
|
@@ -26,6 +26,7 @@ _HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
|||
"trash_gmail_email": "email_subject_or_id",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"delete_google_drive_file": "file_name",
|
||||
"delete_onedrive_file": "file_name",
|
||||
"delete_notion_page": "page_title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
|
|
|
|||
|
|
@@ -15,14 +15,19 @@ import logging
|
|||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from litellm import token_counter
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.utils.document_converters import embed_texts
|
||||
|
|
@@ -32,6 +37,23 @@ logger = logging.getLogger(__name__)
|
|||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class KBSearchPlan(BaseModel):
    """Structured internal plan for KB retrieval."""

    optimized_query: str = Field(
        min_length=1,
        description="Optimized retrieval query preserving the user's intent.",
    )
    start_date: str | None = Field(
        default=None,
        description="Optional ISO start date or datetime for KB search filtering.",
    )
    end_date: str | None = Field(
        default=None,
        description="Optional ISO end date or datetime for KB search filtering.",
    )
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
"""Extract plain text from a message content."""
|
||||
content = getattr(message, "content", "")
|
||||
|
|
@@ -61,6 +83,212 @@ def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
|||
return name
|
||||
|
||||
|
||||
def _render_recent_conversation(
    messages: Sequence[BaseMessage],
    *,
    llm: BaseChatModel | None = None,
    user_text: str = "",
    max_messages: int = 6,
) -> str:
    """Render recent dialogue for internal planning under a token budget.

    Prefers the latest messages and uses the project's existing model-aware
    token budgeting hooks when available on the LLM (`_count_tokens`,
    `_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
    if token counting is unavailable.
    """
    rendered: list[tuple[str, str]] = []
    for message in messages:
        role: str | None = None
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            if getattr(message, "tool_calls", None):
                continue
            role = "assistant"
        else:
            continue

        text = _extract_text_from_message(message).strip()
        if not text:
            continue
        text = re.sub(r"\s+", " ", text)
        rendered.append((role, text))

    if not rendered:
        return ""

    # Exclude the latest user message from "recent conversation" because it is
    # already passed separately as "Latest user message" in the planner prompt.
    if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
        rendered = rendered[:-1]

    if not rendered:
        return ""

    def _legacy_render() -> str:
        legacy_lines: list[str] = []
        for role, text in rendered[-max_messages:]:
            clipped = text[:400].rstrip() + "..." if len(text) > 400 else text
            legacy_lines.append(f"{role}: {clipped}")
        return "\n".join(legacy_lines)

    def _count_prompt_tokens(conversation_text: str) -> int | None:
        prompt = _build_kb_planner_prompt(
            recent_conversation=conversation_text or "(none)",
            user_text=user_text,
        )
        message_payload = [{"role": "user", "content": prompt}]

        count_fn = getattr(llm, "_count_tokens", None) if llm is not None else None
        if callable(count_fn):
            try:
                return count_fn(message_payload)
            except Exception:
                pass

        profile = getattr(llm, "profile", None) if llm is not None else None
        model_names: list[str] = []
        if isinstance(profile, dict):
            tcms = profile.get("token_count_models")
            if isinstance(tcms, list):
                model_names.extend(
                    name for name in tcms if isinstance(name, str) and name
                )
            tcm = profile.get("token_count_model")
            if isinstance(tcm, str) and tcm and tcm not in model_names:
                model_names.append(tcm)
        model_name = model_names[0] if model_names else getattr(llm, "model", None)
        if not isinstance(model_name, str) or not model_name:
            return None
        try:
            return token_counter(messages=message_payload, model=model_name)
        except Exception:
            return None

    get_max_input_tokens = getattr(llm, "_get_max_input_tokens", None) if llm else None
    if callable(get_max_input_tokens):
        try:
            max_input_tokens = int(get_max_input_tokens())
        except Exception:
            max_input_tokens = None
    else:
        profile = getattr(llm, "profile", None) if llm is not None else None
        max_input_tokens = (
            profile.get("max_input_tokens")
            if isinstance(profile, dict)
            and isinstance(profile.get("max_input_tokens"), int)
            else None
        )

    if not isinstance(max_input_tokens, int) or max_input_tokens <= 0:
        return _legacy_render()

    output_reserve = min(max(int(max_input_tokens * 0.02), 256), 1024)
    budget = max_input_tokens - output_reserve
    if budget <= 0:
        return _legacy_render()

    selected_lines: list[str] = []
    for role, text in reversed(rendered):
        candidate_line = f"{role}: {text}"
        candidate_lines = [candidate_line, *selected_lines]
        candidate_conversation = "\n".join(candidate_lines)
        token_count = _count_prompt_tokens(candidate_conversation)
        if token_count is None:
            return _legacy_render()
        if token_count <= budget:
            selected_lines = candidate_lines
            continue

        # If the full message does not fit, keep as much of this most-recent
        # older message as possible via binary search.
        lo, hi = 1, len(text)
        best_line: str | None = None
        while lo <= hi:
            mid = (lo + hi) // 2
            clipped_text = text[:mid].rstrip() + "..."
            clipped_line = f"{role}: {clipped_text}"
            clipped_conversation = "\n".join([clipped_line, *selected_lines])
            clipped_tokens = _count_prompt_tokens(clipped_conversation)
            if clipped_tokens is None:
                break
            if clipped_tokens <= budget:
                best_line = clipped_line
                lo = mid + 1
            else:
                hi = mid - 1

        if best_line is not None:
            selected_lines = [best_line, *selected_lines]
        break

    if not selected_lines:
        return _legacy_render()

    return "\n".join(selected_lines)
|
||||
|
||||
|
||||
def _build_kb_planner_prompt(
    *,
    recent_conversation: str,
    user_text: str,
) -> str:
    """Build a compact internal prompt for KB query rewriting and date scoping."""
    today = datetime.now(UTC).date().isoformat()
    return (
        "You optimize internal knowledge-base search inputs for document retrieval.\n"
        "Return JSON only with this exact shape:\n"
        '{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
        "Rules:\n"
        "- Preserve the user's intent.\n"
        "- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
        "- Keep the query concise and retrieval-focused.\n"
        "- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
        "- If you use date filters, prefer returning both bounds.\n"
        "- If no date filter is useful, return null for both dates.\n"
        "- Do not include markdown, prose, or explanations.\n\n"
        f"Today's UTC date: {today}\n\n"
        f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
        f"Latest user message:\n{user_text}"
    )
|
||||
|
||||
|
||||
def _extract_json_payload(text: str) -> str:
    """Extract a JSON object from a raw LLM response."""
    stripped = text.strip()
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
    if fenced:
        return fenced.group(1)

    start = stripped.find("{")
    end = stripped.rfind("}")
    if start != -1 and end != -1 and end > start:
        return stripped[start : end + 1]
    return stripped


def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Parse and validate the planner's JSON response."""
    payload = json.loads(_extract_json_payload(response_text))
    return KBSearchPlan.model_validate(payload)


def _normalize_optional_date_range(
    start_date: str | None,
    end_date: str | None,
) -> tuple[datetime | None, datetime | None]:
    """Normalize optional planner dates into a UTC datetime range."""
    parsed_start = parse_date_or_datetime(start_date) if start_date else None
    parsed_end = parse_date_or_datetime(end_date) if end_date else None

    if parsed_start is None and parsed_end is None:
        return None, None

    resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
    return resolved_start, resolved_end
|
||||
|
||||
|
||||
def _build_document_xml(
|
||||
document: dict[str, Any],
|
||||
matched_chunk_ids: set[int] | None = None,
|
||||
|
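Review note: the clipping step in `_render_recent_conversation` is a binary search over prefix length. It looks for the longest prefix of the message that still fits the token budget once `"..."` is appended. The sketch below isolates that step. `clip_to_budget` and `count_tokens` are hypothetical names, and the counter is passed in as a plain callable rather than the model-aware hooks the commit uses. The search assumes the token count never decreases as the prefix grows, which real tokenizers do not strictly guarantee.

```python
from typing import Callable, Optional


def clip_to_budget(
    text: str,
    budget: int,
    count_tokens: Callable[[str], int],
) -> Optional[str]:
    """Return the longest clipped prefix of `text` that fits `budget`.

    Returns None when even a one-character prefix plus "..." is too large.
    """
    lo, hi = 1, len(text)
    best: Optional[str] = None
    while lo <= hi:
        mid = (lo + hi) // 2
        candidate = text[:mid].rstrip() + "..."
        if count_tokens(candidate) <= budget:
            best = candidate  # fits: remember it and try a longer prefix
            lo = mid + 1
        else:
            hi = mid - 1  # too large: try a shorter prefix
    return best
```

For a quick check, `len` can stand in for the token counter: `clip_to_budget("abcdefghijklmnop", 10, len)` keeps seven characters plus the ellipsis. Each probe in the commit rebuilds and recounts the whole planner prompt, so the logarithmic number of probes is what keeps this affordable.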
|
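Review note: `_extract_json_payload` has three paths: a fenced block, the span from the first `{` to the last `}`, and raw passthrough. The standalone copy below exercises each one. The function body follows the hunk, except that the regex spells the fence as `` `{3} `` so the example itself contains no literal fence; the sample replies are invented for illustration.

```python
import json
import re

FENCE = "`" * 3  # built at runtime so this example contains no literal fence


def extract_json_payload(text: str) -> str:
    """Return the JSON object embedded in a raw model reply, if any."""
    stripped = text.strip()
    # Path 1: a fenced block, optionally tagged "json".
    fenced = re.search(r"`{3}(?:json)?\s*(\{.*?\})\s*`{3}", stripped, re.DOTALL)
    if fenced:
        return fenced.group(1)
    # Path 2: everything from the first "{" to the last "}".
    start, end = stripped.find("{"), stripped.rfind("}")
    if start != -1 and end > start:
        return stripped[start : end + 1]
    # Path 3: nothing that looks like an object; json.loads will raise later.
    return stripped


# Invented sample replies.
fenced_reply = f'{FENCE}json\n{{"optimized_query": "q3 roadmap"}}\n{FENCE}'
chatty_reply = 'Sure! {"optimized_query": "q3 roadmap", "start_date": null} Done.'
```

Path 2 is deliberately loose: a reply containing two separate objects yields a span that `json.loads` rejects, which in the commit lands in the `json.JSONDecodeError` fallback and the raw user query is used instead.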
@@ -264,6 +492,8 @@ async def search_knowledge_base(
|
|||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run a single unified hybrid search against the knowledge base.
|
||||
|
||||
|
|
@@ -286,6 +516,8 @@ async def search_knowledge_base(
|
|||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=doc_types,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_embedding=embedding.tolist(),
|
||||
)
|
||||
|
||||
|
|
@@ -346,16 +578,71 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm: BaseChatModel | None = None,
|
||||
search_space_id: int,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.search_space_id = search_space_id
|
||||
self.available_connectors = available_connectors
|
||||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
|
||||
    async def _plan_search_inputs(
        self,
        *,
        messages: Sequence[BaseMessage],
        user_text: str,
    ) -> tuple[str, datetime | None, datetime | None]:
        """Rewrite the KB query and infer optional date filters with the LLM."""
        if self.llm is None:
            return user_text, None, None

        recent_conversation = _render_recent_conversation(
            messages,
            llm=self.llm,
            user_text=user_text,
        )
        prompt = _build_kb_planner_prompt(
            recent_conversation=recent_conversation,
            user_text=user_text,
        )
        loop = asyncio.get_running_loop()
        t0 = loop.time()

        try:
            response = await self.llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )
            plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
            optimized_query = (
                re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
            )
            start_date, end_date = _normalize_optional_date_range(
                plan.start_date,
                plan.end_date,
            )
            _perf_log.info(
                "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
                loop.time() - t0,
                user_text[:80],
                optimized_query[:120],
                start_date.isoformat() if start_date else None,
                end_date.isoformat() if end_date else None,
            )
            return optimized_query, start_date, end_date
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            logger.warning(
                "KB planner returned invalid output, using raw query: %s", exc
            )
        except Exception as exc:  # pragma: no cover - defensive fallback
            logger.warning("KB planner failed, using raw query: %s", exc)

        return user_text, None, None
|
||||
|
||||
def before_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
|
|
@@ -388,13 +675,19 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||
existing_files = state.get("files")
|
||||
planned_query, start_date, end_date = await self._plan_search_inputs(
|
||||
messages=messages,
|
||||
user_text=user_text,
|
||||
)
|
||||
|
||||
search_results = await search_knowledge_base(
|
||||
- query=user_text,
+ query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
available_connectors=self.available_connectors,
|
||||
available_document_types=self.available_document_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
new_files = await build_scoped_filesystem(
|
||||
documents=search_results,
|
||||
|
|
@@ -405,9 +698,10 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
if t0 is not None:
|
||||
_perf_log.info(
- "[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
+ "[kb_fs_middleware] completed in %.3fs query=%r optimized=%r new_files=%d total=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
planned_query[:120],
|
||||
len(new_files),
|
||||
len(new_files) + len(existing_files or {}),
|
||||
)
|
||||
|
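Review note: the middleware change relies on `_plan_search_inputs` never raising. Any planner failure degrades to the raw user text with no date filters. The sketch below shows that fallback pattern using only the standard library. `plan_query`, `fake_llm_ok` and `fake_llm_broken` are hypothetical stand-ins, plain JSON parsing replaces the Pydantic validation used in the commit, and the date strings are returned unparsed.

```python
import asyncio
import json


async def plan_query(user_text, call_llm):
    """Return (query, start, end); fall back to the raw text on planner failure."""
    try:
        raw = await call_llm(user_text)
        payload = json.loads(raw)
        # Collapse whitespace; an empty rewrite falls back to the raw text.
        query = " ".join(str(payload["optimized_query"]).split()) or user_text
        return query, payload.get("start_date"), payload.get("end_date")
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return user_text, None, None


# Hypothetical stand-ins for the chat model.
async def fake_llm_ok(prompt):
    return '{"optimized_query": "  q3   roadmap ", "start_date": "2025-01-01", "end_date": null}'


async def fake_llm_broken(prompt):
    return "I could not produce JSON."
```

Running `asyncio.run(plan_query("raw question", fake_llm_broken))` returns the raw question unchanged, so a misbehaving planner costs one extra model call but does not block retrieval.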
|
|
|||
|
|
@@ -0,0 +1,11 @@
|
|||
from app.agents.new_chat.tools.dropbox.create_file import (
    create_create_dropbox_file_tool,
)
from app.agents.new_chat.tools.dropbox.trash_file import (
    create_delete_dropbox_file_tool,
)

__all__ = [
    "create_create_dropbox_file_tool",
    "create_delete_dropbox_file_tool",
]
|
||||
|
|
@@ -0,0 +1,301 @@
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
|
||||
_FILE_TYPE_LABELS = {
|
||||
"paper": "Dropbox Paper (.paper)",
|
||||
"docx": "Word Document (.docx)",
|
||||
}
|
||||
|
||||
_SUPPORTED_TYPES = [
|
||||
{"value": "paper", "label": "Dropbox Paper (.paper)"},
|
||||
{"value": "docx", "label": "Word Document (.docx)"},
|
||||
]
|
||||
|
||||
|
||||
def _ensure_extension(name: str, file_type: str) -> str:
    """Strip any existing extension and append the correct one."""
    stem = Path(name).stem
    ext = ".paper" if file_type == "paper" else ".docx"
    return f"{stem}{ext}"


def _markdown_to_docx(markdown_text: str) -> bytes:
    """Convert a markdown string to DOCX bytes using pypandoc."""
    import pypandoc

    fd, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(fd)
    try:
        pypandoc.convert_text(
            markdown_text,
            "docx",
            format="gfm",
            extra_args=["--standalone"],
            outputfile=tmp_path,
        )
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)
|
||||
|
||||
|
||||
def create_create_dropbox_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_dropbox_file(
|
||||
name: str,
|
||||
file_type: Literal["paper", "docx"] = "paper",
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new document in Dropbox.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new document
|
||||
in Dropbox. The user MUST specify a topic before you call this tool.
|
||||
|
||||
Args:
|
||||
name: The document title (without extension).
|
||||
file_type: Either "paper" (Dropbox Paper, default) or "docx" (Word document).
|
||||
content: Optional initial content as markdown.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, file_id, name, web_url, and message.
|
||||
"""
|
||||
logger.info(
|
||||
f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Dropbox accounts need re-authentication.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = DropboxClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_folder("")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{
|
||||
"folder_path": item.get("path_lower", ""),
|
||||
"name": item["name"],
|
||||
}
|
||||
for item in items
|
||||
if item.get(".tag") == "folder" and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s", cid, exc_info=True
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
"supported_types": _SUPPORTED_TYPES,
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "dropbox_file_creation",
|
||||
"action": {
|
||||
"tool": "create_dropbox_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_path": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_file_type = final_params.get("file_type", file_type)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_path = final_params.get("parent_folder_path")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_extension(final_name, final_file_type)
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
connector = connectors[0]
|
||||
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid.",
|
||||
}
|
||||
|
||||
client = DropboxClient(session=db_session, connector_id=connector.id)
|
||||
|
||||
parent_path = final_parent_folder_path or ""
|
||||
file_path = (
|
||||
f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
|
||||
)
|
||||
|
||||
if final_file_type == "paper":
|
||||
created = await client.create_paper_doc(file_path, final_content or "")
|
||||
file_id = created.get("file_id", "")
|
||||
web_url = created.get("url", "")
|
||||
else:
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
created = await client.upload_file(
|
||||
file_path, docx_bytes, mode="add", autorename=True
|
||||
)
|
||||
file_id = created.get("id", "")
|
||||
web_url = ""
|
||||
|
||||
logger.info(f"Dropbox file created: id={file_id}, name={final_name}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.dropbox import DropboxKBSyncService
|
||||
|
||||
kb_service = DropboxKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=file_id,
|
||||
file_name=final_name,
|
||||
file_path=file_path,
|
||||
web_url=web_url,
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"name": final_name,
|
||||
"web_url": web_url,
|
||||
"message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Dropbox file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
|
||||
return create_dropbox_file
|
||||
|
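Review note: `_ensure_extension` relies on `Path(name).stem`, which treats the last dot as an extension separator and drops any directory part. The copy below uses the same logic as the hunk, with invented sample names. It suggests that titles containing a dot lose their tail, which may or may not be intended.

```python
from pathlib import Path


def ensure_extension(name: str, file_type: str) -> str:
    """Same logic as the hunk: keep the stem, append the type's extension."""
    ext = ".paper" if file_type == "paper" else ".docx"
    return f"{Path(name).stem}{ext}"


# Invented sample titles.
examples = {
    "notes.md": ensure_extension("notes.md", "docx"),
    "Q3 plan v1.2": ensure_extension("Q3 plan v1.2", "paper"),
    "drafts/outline": ensure_extension("drafts/outline", "paper"),
}
```

Here `"Q3 plan v1.2"` becomes `"Q3 plan v1.paper"` because `.2` is parsed as a suffix, and `"drafts/outline"` becomes `"outline.paper"` because the stem ignores parent directories.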
|
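Review note: both Dropbox tools normalise the value returned by `interrupt(...)` in the same way. They accept a list or a single decision, ignore non-dict entries, stop on `reject`, and otherwise let edited arguments override the original parameters. The sketch below extracts that logic into a pure function so it can be checked without LangGraph. `resolve_decision` and its return shape are hypothetical; the key names (`decisions`, `type`, `decision_type`, `edited_action`, `args`) come from the diff.

```python
from typing import Any


def resolve_decision(
    approval: Any, defaults: dict[str, Any]
) -> tuple[str, dict[str, Any]]:
    """Return (status, final_params) for a human-in-the-loop approval payload."""
    raw = approval.get("decisions", []) if isinstance(approval, dict) else []
    decisions = raw if isinstance(raw, list) else [raw]
    decisions = [d for d in decisions if isinstance(d, dict)]
    if not decisions:
        return "error", {}

    decision = decisions[0]
    if (decision.get("type") or decision.get("decision_type")) == "reject":
        return "rejected", {}

    overrides: dict[str, Any] = {}
    edited_action = decision.get("edited_action")
    if isinstance(edited_action, dict):
        if isinstance(edited_action.get("args"), dict):
            overrides = edited_action["args"]
    elif isinstance(decision.get("args"), dict):
        overrides = decision["args"]

    # Edited values win; anything not edited keeps the tool's original value.
    return "approved", {**defaults, **overrides}
```

Sharing one helper like this between the create and delete tools would keep the two approval paths from drifting apart, although the commit keeps them inline.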
@@ -0,0 +1,304 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy import String, and_, cast, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_dropbox_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_dropbox_file(
|
||||
file_name: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a file from Dropbox.
|
||||
|
||||
Use this tool when the user explicitly asks to delete, remove, or trash
|
||||
a file in Dropbox.
|
||||
|
||||
Args:
|
||||
file_name: The exact name of the file to delete.
|
||||
delete_from_kb: Whether to also remove the file from the knowledge base.
|
||||
Default is False.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- file_id: Dropbox file ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the file name or check if it has been indexed.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.DROPBOX_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata["dropbox_file_name"],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed Dropbox files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
meta = document.document_metadata or {}
|
||||
file_path = meta.get("dropbox_path")
|
||||
file_id = meta.get("dropbox_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if not file_path:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File path is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Dropbox connector not found or access denied.",
|
||||
}
|
||||
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "dropbox",
|
||||
}
|
||||
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"file_path": file_path,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "dropbox_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_dropbox_file",
|
||||
"params": {
|
||||
"file_path": file_path,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_path = final_params.get("file_path", file_path)
|
||||
final_connector_id = final_params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Dropbox connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
client = DropboxClient(session=db_session, connector_id=actual_connector_id)
|
||||
await client.delete_file(final_file_path)
|
||||
|
||||
logger.info(f"Dropbox file deleted: path={final_file_path}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": file_id,
|
||||
"message": f"Successfully deleted '{file_name}' from Dropbox.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting Dropbox file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the file. Please try again.",
|
||||
}
|
||||
|
||||
return delete_dropbox_file
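The approval handling in the tool above (normalizing `decisions`, reading the decision type, then picking up edited arguments) can be read as one small pure function. The following is a minimal sketch, not part of this commit: the helper name `resolve_decision` is hypothetical, and it only mirrors the dictionary shapes visible in the diff (`decisions`, `type` / `decision_type`, `edited_action.args`, `args`).

```python
from __future__ import annotations

from typing import Any


def resolve_decision(approval: Any) -> tuple[str | None, dict[str, Any]]:
    """Return (decision_type, edited_params) from an interrupt() resume value.

    Hypothetical helper: it follows the same normalization steps as the tool body.
    """
    # Accept a dict with a "decisions" entry; anything else means no decision.
    raw = approval.get("decisions", []) if isinstance(approval, dict) else []
    # A single decision may arrive unwrapped, so wrap it in a list.
    candidates = raw if isinstance(raw, list) else [raw]
    decisions = [d for d in candidates if isinstance(d, dict)]
    if not decisions:
        return None, {}

    decision = decisions[0]
    decision_type = decision.get("type") or decision.get("decision_type")

    # Edited arguments live under edited_action.args, or directly under args.
    params: dict[str, Any] = {}
    edited_action = decision.get("edited_action")
    if isinstance(edited_action, dict):
        edited_args = edited_action.get("args")
        if isinstance(edited_args, dict):
            params = edited_args
    elif isinstance(decision.get("args"), dict):
        params = decision["args"]
    return decision_type, params
```

With a helper of this shape, the caller would fall back to its own defaults for any key missing from the returned parameters, as the tool does with `final_params.get("file_path", file_path)`.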
|
||||
|
|
@ -201,6 +201,8 @@ _ALL_CONNECTORS: list[str] = [
|
|||
"CRAWLED_URL",
|
||||
"CIRCLEBACK",
|
||||
"OBSIDIAN_CONNECTOR",
|
||||
"ONEDRIVE_FILE",
|
||||
"DROPBOX_FILE",
|
||||
]
|
||||
|
||||
# Human-readable descriptions for each connector type
|
||||
|
|
@ -230,6 +232,8 @@ CONNECTOR_DESCRIPTIONS: dict[str, str] = {
|
|||
"BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
|
||||
"CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
|
||||
"OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
|
||||
"ONEDRIVE_FILE": "Microsoft OneDrive files and documents (personal cloud storage)",
|
||||
"DROPBOX_FILE": "Dropbox files and documents (cloud storage)",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -357,6 +361,8 @@ _INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
|
|||
"event_id",
|
||||
"calendar_id",
|
||||
"google_drive_file_id",
|
||||
"onedrive_file_id",
|
||||
"dropbox_file_id",
|
||||
"page_id",
|
||||
"issue_id",
|
||||
"connector_id",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
from app.agents.new_chat.tools.onedrive.create_file import (
    create_create_onedrive_file_tool,
)
from app.agents.new_chat.tools.onedrive.trash_file import (
    create_delete_onedrive_file_tool,
)

__all__ = [
    "create_create_onedrive_file_tool",
    "create_delete_onedrive_file_tool",
]
|
|
@ -0,0 +1,278 @@
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.connectors.onedrive.client import OneDriveClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
|
||||
|
||||
def _ensure_docx_extension(name: str) -> str:
    """Strip any existing extension and append .docx."""
    stem = Path(name).stem
    return f"{stem}.docx"


def _markdown_to_docx(markdown_text: str) -> bytes:
    """Convert a markdown string to DOCX bytes using pypandoc."""
    import pypandoc

    fd, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(fd)
    try:
        pypandoc.convert_text(
            markdown_text,
            "docx",
            format="gfm",
            extra_args=["--standalone"],
            outputfile=tmp_path,
        )
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)
|
||||
|
||||
|
||||
def create_create_onedrive_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_onedrive_file(
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Word document (.docx) in Microsoft OneDrive.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new document
|
||||
in OneDrive. The user MUST specify a topic before you call this tool.
|
||||
|
||||
The file is always saved as a .docx Word document. Provide content as
|
||||
markdown and it will be automatically converted to a formatted Word file.
|
||||
|
||||
Args:
|
||||
name: The document title (without extension). Extension will be set to .docx automatically.
|
||||
content: Optional initial content as markdown. Will be converted to a formatted Word document.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, file_id, name, web_url, and message.
|
||||
"""
|
||||
logger.info(f"create_onedrive_file called: name='{name}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
|
||||
}
|
||||
|
||||
accounts = []
|
||||
for c in connectors:
|
||||
cfg = c.config or {}
|
||||
accounts.append(
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
"auth_expired": cfg.get("auth_expired", False),
|
||||
}
|
||||
)
|
||||
|
||||
if all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected OneDrive accounts need re-authentication.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
|
||||
parent_folders: dict[int, list[dict[str, str]]] = {}
|
||||
for acc in accounts:
|
||||
cid = acc["id"]
|
||||
if acc.get("auth_expired"):
|
||||
parent_folders[cid] = []
|
||||
continue
|
||||
try:
|
||||
client = OneDriveClient(session=db_session, connector_id=cid)
|
||||
items, err = await client.list_children("root")
|
||||
if err:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s", cid, err
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
else:
|
||||
parent_folders[cid] = [
|
||||
{"folder_id": item["id"], "name": item["name"]}
|
||||
for item in items
|
||||
if item.get("folder") is not None
|
||||
and item.get("id")
|
||||
and item.get("name")
|
||||
]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error fetching folders for connector %s", cid, exc_info=True
|
||||
)
|
||||
parent_folders[cid] = []
|
||||
|
||||
context: dict[str, Any] = {
|
||||
"accounts": accounts,
|
||||
"parent_folders": parent_folders,
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "onedrive_file_creation",
|
||||
"action": {
|
||||
"tool": "create_onedrive_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_id = final_params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
final_name = _ensure_docx_extension(final_name)
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
else:
|
||||
connector = connectors[0]
|
||||
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid.",
|
||||
}
|
||||
|
||||
docx_bytes = _markdown_to_docx(final_content or "")
|
||||
|
||||
client = OneDriveClient(session=db_session, connector_id=connector.id)
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
parent_id=final_parent_folder_id,
|
||||
content=docx_bytes,
|
||||
mime_type=DOCX_MIME,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.onedrive import OneDriveKBSyncService
|
||||
|
||||
kb_service = OneDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=DOCX_MIME,
|
||||
web_url=created.get("webUrl"),
|
||||
content=final_content,
|
||||
connector_id=connector.id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_url": created.get("webUrl"),
|
||||
"message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating OneDrive file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
|
||||
return create_onedrive_file
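The `_ensure_docx_extension` helper in this file relies on `Path(name).stem`, which drops everything after the last dot and any leading directory part. The sketch below reproduces that one-line logic under a different name (`ensure_docx_extension`, chosen here only for illustration) to show how a few titles would be renamed. It may be worth checking whether titles that contain a dot, such as version numbers, are expected input.

```python
from pathlib import Path


def ensure_docx_extension(name: str) -> str:
    """Same logic as the helper in this file: keep the stem, append .docx."""
    return f"{Path(name).stem}.docx"


# Map each sample title to the file name the helper would produce.
examples = {
    title: ensure_docx_extension(title)
    for title in ("Meeting notes", "report.md", "archive.tar.gz", "v1.2 notes")
}
```

Here `"report.md"` becomes `"report.docx"`, while `"v1.2 notes"` becomes `"v1.docx"` because `.2 notes` is treated as the extension.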
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy import String, and_, cast, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.connectors.onedrive.client import OneDriveClient
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_onedrive_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_onedrive_file(
|
||||
file_name: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Move a OneDrive file to the recycle bin.
|
||||
|
||||
Use this tool when the user explicitly asks to delete, remove, or trash
|
||||
a file in OneDrive.
|
||||
|
||||
Args:
|
||||
file_name: The exact name of the file to trash.
|
||||
delete_from_kb: Whether to also remove the file from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both OneDrive and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- file_id: OneDrive file ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the file name or check if it has been indexed.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector,
|
||||
Document.connector_id == SearchSourceConnector.id,
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
func.lower(
|
||||
cast(
|
||||
Document.document_metadata["onedrive_file_name"],
|
||||
String,
|
||||
)
|
||||
)
|
||||
== func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.order_by(Document.updated_at.desc().nullslast())
|
||||
.limit(1)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": (
|
||||
f"File '{file_name}' not found in your indexed OneDrive files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
),
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Document has no associated connector.",
|
||||
}
|
||||
|
||||
meta = document.document_metadata or {}
|
||||
file_id = meta.get("onedrive_file_id")
|
||||
document_id = document.id
|
||||
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing. Please re-index the file.",
|
||||
}
|
||||
|
||||
conn_result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = conn_result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "OneDrive connector not found or access denied.",
|
||||
}
|
||||
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "onedrive",
|
||||
}
|
||||
|
||||
context = {
|
||||
"file": {
|
||||
"file_id": file_id,
|
||||
"name": file_name,
|
||||
"document_id": document_id,
|
||||
"web_url": meta.get("web_url"),
|
||||
},
|
||||
"account": {
|
||||
"id": connector.id,
|
||||
"name": connector.name,
|
||||
"user_email": cfg.get("user_email"),
|
||||
},
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "onedrive_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_onedrive_file",
|
||||
"params": {
|
||||
"file_id": file_id,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_id = final_params.get("file_id", file_id)
|
||||
final_connector_id = final_params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
validated_connector = result.scalars().first()
|
||||
if not validated_connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected OneDrive connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = validated_connector.id
|
||||
else:
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
client = OneDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
await client.trash_file(final_file_id)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file_name}' to the recycle bin.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
doc = doc_result.scalars().first()
|
||||
if doc:
|
||||
await db_session.delete(doc)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting OneDrive file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the file. Please try again.",
|
||||
}
|
||||
|
||||
return delete_onedrive_file
|
||||
|
|
@ -50,6 +50,10 @@ from .confluence import (
|
|||
create_delete_confluence_page_tool,
|
||||
create_update_confluence_page_tool,
|
||||
)
|
||||
from .dropbox import (
|
||||
create_create_dropbox_file_tool,
|
||||
create_delete_dropbox_file_tool,
|
||||
)
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .gmail import (
|
||||
create_create_gmail_draft_tool,
|
||||
|
|
@ -82,6 +86,10 @@ from .notion import (
|
|||
create_delete_notion_page_tool,
|
||||
create_update_notion_page_tool,
|
||||
)
|
||||
from .onedrive import (
|
||||
create_create_onedrive_file_tool,
|
||||
create_delete_onedrive_file_tool,
|
||||
)
|
||||
from .podcast import create_generate_podcast_tool
|
||||
from .report import create_generate_report_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
|
|
@ -336,6 +344,54 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# DROPBOX TOOLS - create and trash files
|
||||
# Auto-disabled when no Dropbox connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_dropbox_file",
|
||||
description="Create a new file in Dropbox",
|
||||
factory=lambda deps: create_create_dropbox_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_dropbox_file",
|
||||
description="Delete a file from Dropbox",
|
||||
factory=lambda deps: create_delete_dropbox_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# ONEDRIVE TOOLS - create and trash files
|
||||
# Auto-disabled when no OneDrive connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_onedrive_file",
|
||||
description="Create a new file in Microsoft OneDrive",
|
||||
factory=lambda deps: create_create_onedrive_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_onedrive_file",
|
||||
description="Move a OneDrive file to the recycle bin",
|
||||
factory=lambda deps: create_delete_onedrive_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
||||
# Auto-disabled when no Google Calendar connector is configured
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import datetime
|
||||
|
||||
# TODO: move these to config file
|
||||
MAX_SLIDES = 5
|
||||
FPS = 30
|
||||
DEFAULT_DURATION_IN_FRAMES = 300
|
||||
from app.config import config as app_config
|
||||
|
||||
MAX_SLIDES = app_config.VIDEO_PRESENTATION_MAX_SLIDES
|
||||
FPS = app_config.VIDEO_PRESENTATION_FPS
|
||||
DEFAULT_DURATION_IN_FRAMES = app_config.VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES
|
||||
|
||||
THEME_PRESETS = [
|
||||
"TERRA",
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
|
|||
# Format: "<number><unit>" where unit is 'm' (minutes) or 'h' (hours)
|
||||
# Examples: "1m" (every minute), "5m" (every 5 minutes), "1h" (every hour)
|
||||
SCHEDULE_CHECKER_INTERVAL = os.getenv("SCHEDULE_CHECKER_INTERVAL", "2m")
|
||||
STRIPE_RECONCILIATION_INTERVAL = os.getenv("STRIPE_RECONCILIATION_INTERVAL", "10m")
|
||||
|
||||
|
||||
def parse_schedule_interval(interval: str) -> dict:
|
||||
|
|
@ -68,6 +69,9 @@ def parse_schedule_interval(interval: str) -> dict:
|
|||
|
||||
# Parse the schedule interval
|
||||
schedule_params = parse_schedule_interval(SCHEDULE_CHECKER_INTERVAL)
|
||||
stripe_reconciliation_schedule_params = parse_schedule_interval(
|
||||
STRIPE_RECONCILIATION_INTERVAL
|
||||
)
|
||||
|
||||
# Create Celery app
|
||||
celery_app = Celery(
|
||||
|
|
@ -82,6 +86,7 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.schedule_checker_task",
|
||||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -169,4 +174,12 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60, # Task expires after 60 seconds if not picked up
|
||||
},
|
||||
},
|
||||
# Reconcile Stripe purchases that were paid but remained pending
|
||||
"reconcile-pending-stripe-page-purchases": {
|
||||
"task": "reconcile_pending_stripe_page_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
}
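The body of `parse_schedule_interval` is not part of this diff; only its documented input format (`"<number><unit>"` with `m` or `h`) and its use as `crontab(**params)` are visible. As a rough sketch of how such a string could map to crontab keyword arguments, assuming the keys `minute` and `hour` and a hypothetical function name `parse_interval_sketch`:

```python
import re


def parse_interval_sketch(interval: str) -> dict[str, str]:
    """Turn "2m" or "1h" into crontab-style keyword arguments.

    Hypothetical stand-in; the real parse_schedule_interval may differ.
    """
    match = re.fullmatch(r"(\d+)([mh])", interval.strip().lower())
    if match is None:
        raise ValueError(f"Unsupported interval format: {interval!r}")

    value, unit = int(match.group(1)), match.group(2)
    if value <= 0:
        raise ValueError("Interval must be a positive number")

    if unit == "m":
        # Every N minutes.
        return {"minute": f"*/{value}"}
    # Every N hours, on the hour.
    return {"minute": "0", "hour": f"*/{value}"}
```

Under these assumptions the default `STRIPE_RECONCILIATION_INTERVAL` of `"10m"` would yield `{"minute": "*/10"}`. Step values that do not divide 60 (or 24 for hours) would not produce evenly spaced runs with cron-style steps.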
|
||||
|
|
|
|||
|
|
@ -231,6 +231,21 @@ class Config:
|
|||
# Backend URL to override the http to https in the OAuth redirect URI
|
||||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
|
||||
# Stripe checkout for pay-as-you-go page packs
|
||||
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
||||
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
STRIPE_PRICE_ID = os.getenv("STRIPE_PRICE_ID")
|
||||
STRIPE_PAGES_PER_UNIT = int(os.getenv("STRIPE_PAGES_PER_UNIT", "1000"))
|
||||
STRIPE_PAGE_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_PAGE_BUYING_ENABLED", "TRUE").upper() == "TRUE"
|
||||
)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES = int(
|
||||
os.getenv("STRIPE_RECONCILIATION_LOOKBACK_MINUTES", "10")
|
||||
)
|
||||
STRIPE_RECONCILIATION_BATCH_SIZE = int(
|
||||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Auth
|
||||
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
||||
REGISTRATION_ENABLED = os.getenv("REGISTRATION_ENABLED", "TRUE").upper() == "TRUE"
|
||||
|
|
@@ -281,16 +296,22 @@ class Config:
     DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI")
     DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")

-    # Microsoft Teams OAuth
-    TEAMS_CLIENT_ID = os.getenv("TEAMS_CLIENT_ID")
-    TEAMS_CLIENT_SECRET = os.getenv("TEAMS_CLIENT_SECRET")
+    # Microsoft OAuth (shared for Teams and OneDrive)
+    MICROSOFT_CLIENT_ID = os.getenv("MICROSOFT_CLIENT_ID")
+    MICROSOFT_CLIENT_SECRET = os.getenv("MICROSOFT_CLIENT_SECRET")
     TEAMS_REDIRECT_URI = os.getenv("TEAMS_REDIRECT_URI")
+    ONEDRIVE_REDIRECT_URI = os.getenv("ONEDRIVE_REDIRECT_URI")

     # ClickUp OAuth
     CLICKUP_CLIENT_ID = os.getenv("CLICKUP_CLIENT_ID")
     CLICKUP_CLIENT_SECRET = os.getenv("CLICKUP_CLIENT_SECRET")
     CLICKUP_REDIRECT_URI = os.getenv("CLICKUP_REDIRECT_URI")

+    # Dropbox OAuth
+    DROPBOX_APP_KEY = os.getenv("DROPBOX_APP_KEY")
+    DROPBOX_APP_SECRET = os.getenv("DROPBOX_APP_SECRET")
+    DROPBOX_REDIRECT_URI = os.getenv("DROPBOX_REDIRECT_URI")
+
     # Composio Configuration (for managed OAuth integrations)
     # Get your API key from https://app.composio.dev
     COMPOSIO_API_KEY = os.getenv("COMPOSIO_API_KEY")
@@ -394,6 +415,15 @@ class Config:
     STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
     STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")

+    # Video presentation defaults
+    VIDEO_PRESENTATION_MAX_SLIDES = int(
+        os.getenv("VIDEO_PRESENTATION_MAX_SLIDES", "30")
+    )
+    VIDEO_PRESENTATION_FPS = int(os.getenv("VIDEO_PRESENTATION_FPS", "30"))
+    VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES = int(
+        os.getenv("VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES", "300")
+    )
+
     # Validation Checks
     # Check embedding dimension
     if (
|||
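Review note on the `Config` hunks: the new settings are parsed with two patterns, `os.getenv(NAME, "TRUE").upper() == "TRUE"` for flags and `int(os.getenv(NAME, default))` for numbers. The following is a minimal sketch of how those patterns behave. The helper names `env_flag` and `env_int` are hypothetical and do not exist in the commit; they only wrap the expressions shown in the diff.

```python
import os


def env_flag(name: str, default: str = "TRUE") -> bool:
    """Hypothetical helper mirroring `os.getenv(name, default).upper() == "TRUE"`."""
    return os.getenv(name, default).upper() == "TRUE"


def env_int(name: str, default: str) -> int:
    """Hypothetical helper mirroring `int(os.getenv(name, default))`."""
    return int(os.getenv(name, default))


# Unset variable: the default "TRUE" applies, so the flag is on.
os.environ.pop("STRIPE_PAGE_BUYING_ENABLED", None)
flag_when_unset = env_flag("STRIPE_PAGE_BUYING_ENABLED")

# Case-insensitive match on the word "true".
os.environ["STRIPE_PAGE_BUYING_ENABLED"] = "true"
flag_when_lowercase = env_flag("STRIPE_PAGE_BUYING_ENABLED")

# Any other spelling, including "1" or "yes", turns the flag off.
os.environ["STRIPE_PAGE_BUYING_ENABLED"] = "1"
flag_when_numeric = env_flag("STRIPE_PAGE_BUYING_ENABLED")

# Integer settings are converted eagerly.
os.environ["STRIPE_PAGES_PER_UNIT"] = "2500"
pages_per_unit = env_int("STRIPE_PAGES_PER_UNIT", "1000")

# A variable that is set but empty does not fall back to the default:
# os.getenv returns "" and int("") raises ValueError.
os.environ["STRIPE_PAGES_PER_UNIT"] = ""
try:
    env_int("STRIPE_PAGES_PER_UNIT", "1000")
    empty_value_raises = False
except ValueError:
    empty_value_raises = True
```

Under this reading, deployments that export an empty `STRIPE_*` integer variable would fail when the module is imported, so leaving the variable unset is the safer way to get the default.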
13
surfsense_backend/app/connectors/dropbox/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""Dropbox Connector Module."""

from .client import DropboxClient
from .content_extractor import download_and_extract_content
from .folder_manager import get_file_by_path, get_files_in_folder, list_folder_contents

__all__ = [
    "DropboxClient",
    "download_and_extract_content",
    "get_file_by_path",
    "get_files_in_folder",
    "list_folder_contents",
]
331
surfsense_backend/app/connectors/dropbox/client.py
Normal file
@@ -0,0 +1,331 @@
"""Dropbox API client using Dropbox HTTP API v2."""

import json
import logging
from datetime import UTC, datetime, timedelta
from typing import Any

import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.db import SearchSourceConnector
from app.utils.oauth_security import TokenEncryption

logger = logging.getLogger(__name__)

API_BASE = "https://api.dropboxapi.com"
CONTENT_BASE = "https://content.dropboxapi.com"
TOKEN_URL = "https://api.dropboxapi.com/oauth2/token"


class DropboxClient:
    """Client for Dropbox via the HTTP API v2."""

    def __init__(self, session: AsyncSession, connector_id: int):
        self._session = session
        self._connector_id = connector_id

    async def _get_valid_token(self) -> str:
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        is_encrypted = cfg.get("_token_encrypted", False)
        token_encryption = (
            TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None
        )

        access_token = cfg.get("access_token", "")
        refresh_token = cfg.get("refresh_token")

        if is_encrypted and token_encryption:
            if access_token:
                access_token = token_encryption.decrypt_token(access_token)
            if refresh_token:
                refresh_token = token_encryption.decrypt_token(refresh_token)

        expires_at_str = cfg.get("expires_at")
        is_expired = False
        if expires_at_str:
            expires_at = datetime.fromisoformat(expires_at_str)
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=UTC)
            is_expired = expires_at <= datetime.now(UTC)

        if not is_expired and access_token:
            return access_token

        if not refresh_token:
            cfg["auth_expired"] = True
            connector.config = cfg
            flag_modified(connector, "config")
            await self._session.commit()
            raise ValueError("Dropbox token expired and no refresh token available")

        token_data = await self._refresh_token(refresh_token)

        new_access = token_data["access_token"]
        expires_in = token_data.get("expires_in")

        new_expires_at = None
        if expires_in:
            new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

        if token_encryption:
            cfg["access_token"] = token_encryption.encrypt_token(new_access)
        else:
            cfg["access_token"] = new_access

        cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
        cfg["expires_in"] = expires_in
        cfg["_token_encrypted"] = bool(token_encryption)
        cfg.pop("auth_expired", None)

        connector.config = cfg
        flag_modified(connector, "config")
        await self._session.commit()

        return new_access

    async def _refresh_token(self, refresh_token: str) -> dict:
        data = {
            "client_id": config.DROPBOX_APP_KEY,
            "client_secret": config.DROPBOX_APP_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        }
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                TOKEN_URL,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )
            if resp.status_code != 200:
                error_detail = resp.text
                try:
                    error_json = resp.json()
                    error_detail = error_json.get("error_description", error_detail)
                except Exception:
                    pass
                raise ValueError(f"Dropbox token refresh failed: {error_detail}")
            return resp.json()

    async def _request(
        self, path: str, json_body: dict | None = None, **kwargs
    ) -> httpx.Response:
        """Make an authenticated RPC request to the Dropbox API."""
        token = await self._get_valid_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{API_BASE}{path}",
                headers=headers,
                json=json_body,
                timeout=60.0,
                **kwargs,
            )

        if resp.status_code == 401:
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()
            if connector:
                cfg = connector.config or {}
                cfg["auth_expired"] = True
                connector.config = cfg
                flag_modified(connector, "config")
                await self._session.commit()
            raise ValueError("Dropbox authentication expired (401)")

        return resp

    async def _content_request(
        self, path: str, api_arg: dict, content: bytes | None = None, **kwargs
    ) -> httpx.Response:
        """Make an authenticated content-upload/download request."""
        token = await self._get_valid_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Dropbox-API-Arg": json.dumps(api_arg),
            "Content-Type": "application/octet-stream",
        }
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{CONTENT_BASE}{path}",
                headers=headers,
                content=content or b"",
                timeout=120.0,
                **kwargs,
            )

        if resp.status_code == 401:
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()
            if connector:
                cfg = connector.config or {}
                cfg["auth_expired"] = True
                connector.config = cfg
                flag_modified(connector, "config")
                await self._session.commit()
            raise ValueError("Dropbox authentication expired (401)")

        return resp

    async def list_folder(
        self, path: str = ""
    ) -> tuple[list[dict[str, Any]], str | None]:
        """List all items in a folder. Handles pagination via cursor."""
        all_items: list[dict[str, Any]] = []

        resp = await self._request(
            "/2/files/list_folder",
            {"path": path, "recursive": False, "include_non_downloadable_files": True},
        )
        if resp.status_code != 200:
            return [], f"Failed to list folder: {resp.status_code} - {resp.text}"

        data = resp.json()
        all_items.extend(data.get("entries", []))

        while data.get("has_more"):
            cursor = data["cursor"]
            resp = await self._request(
                "/2/files/list_folder/continue", {"cursor": cursor}
            )
            if resp.status_code != 200:
                return all_items, f"Pagination failed: {resp.status_code}"
            data = resp.json()
            all_items.extend(data.get("entries", []))

        return all_items, None

    async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
        resp = await self._request("/2/files/get_metadata", {"path": path})
        if resp.status_code != 200:
            return None, f"Failed to get metadata: {resp.status_code} - {resp.text}"
        return resp.json(), None

    async def download_file(self, path: str) -> tuple[bytes | None, str | None]:
        resp = await self._content_request("/2/files/download", {"path": path})
        if resp.status_code != 200:
            return None, f"Download failed: {resp.status_code}"
        return resp.content, None

    async def download_file_to_disk(self, path: str, dest_path: str) -> str | None:
        """Stream file content to disk. Returns error message on failure."""
        token = await self._get_valid_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Dropbox-API-Arg": json.dumps({"path": path}),
        }
        async with (
            httpx.AsyncClient() as client,
            client.stream(
                "POST",
                f"{CONTENT_BASE}/2/files/download",
                headers=headers,
                timeout=120.0,
            ) as resp,
        ):
            if resp.status_code != 200:
                return f"Download failed: {resp.status_code}"
            with open(dest_path, "wb") as f:
                async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024):
                    f.write(chunk)
        return None

    async def export_file(
        self,
        path: str,
        export_format: str | None = None,
    ) -> tuple[bytes | None, str | None]:
        """Export a non-downloadable file (e.g. .paper) via /2/files/export.

        Uses the recommended new API for Paper-as-files.
        Returns (content_bytes, error_message).
        """
        api_arg: dict[str, str] = {"path": path}
        if export_format:
            api_arg["export_format"] = export_format
        resp = await self._content_request("/2/files/export", api_arg)
        if resp.status_code != 200:
            return None, f"Export failed: {resp.status_code} - {resp.text}"
        return resp.content, None

    async def upload_file(
        self,
        path: str,
        content: bytes,
        mode: str = "add",
        autorename: bool = True,
    ) -> dict[str, Any]:
        """Upload a file to Dropbox (up to 150MB)."""
        api_arg = {"path": path, "mode": mode, "autorename": autorename}
        resp = await self._content_request("/2/files/upload", api_arg, content)
        if resp.status_code != 200:
            raise ValueError(f"Upload failed: {resp.status_code} - {resp.text}")
        return resp.json()

    async def create_paper_doc(
        self, path: str, markdown_content: str
    ) -> dict[str, Any]:
        """Create a Dropbox Paper document from markdown."""
        token = await self._get_valid_token()
        api_arg = {"import_format": "markdown", "path": path}
        headers = {
            "Authorization": f"Bearer {token}",
            "Dropbox-API-Arg": json.dumps(api_arg),
            "Content-Type": "application/octet-stream",
        }
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{API_BASE}/2/files/paper/create",
                headers=headers,
                content=markdown_content.encode("utf-8"),
                timeout=60.0,
            )
            if resp.status_code != 200:
                raise ValueError(
                    f"Paper doc creation failed: {resp.status_code} - {resp.text}"
                )
            return resp.json()

    async def delete_file(self, path: str) -> dict[str, Any]:
        """Delete a file or folder."""
        resp = await self._request("/2/files/delete_v2", {"path": path})
        if resp.status_code != 200:
            raise ValueError(f"Delete failed: {resp.status_code} - {resp.text}")
        return resp.json()

    async def get_current_account(self) -> tuple[dict[str, Any] | None, str | None]:
        """Get current user's account info."""
        resp = await self._request("/2/users/get_current_account", None)
        if resp.status_code != 200:
            return None, f"Failed to get account: {resp.status_code}"
        return resp.json(), None
||||
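Review note on `DropboxClient._get_valid_token`: the expiry decision can be looked at in isolation. The sketch below is an illustration only, not code from the commit. The function name `is_token_expired` and the injected `now` parameter are assumptions made so the logic can be checked without a database or network, and `timezone.utc` stands in for `datetime.UTC`.

```python
from datetime import datetime, timezone
from typing import Optional


def is_token_expired(expires_at_str: Optional[str], now: datetime) -> bool:
    """Hypothetical standalone version of the expiry check in the client.

    - A missing timestamp is treated as "not expired".
    - A naive timestamp is interpreted as UTC.
    - A token whose expiry equals `now` already counts as expired.
    """
    if not expires_at_str:
        return False
    expires_at = datetime.fromisoformat(expires_at_str)
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at <= now


# Fixed reference time so the results do not depend on the wall clock.
now = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)

expired_naive = is_token_expired("2026-01-01T11:00:00", now)
valid_aware = is_token_expired("2026-01-01T13:00:00+00:00", now)
boundary = is_token_expired("2026-01-01T12:00:00+00:00", now)
missing = is_token_expired(None, now)
```

One consequence of the "missing means not expired" branch: a stored access token without `expires_at` is reused until the API answers with a 401, at which point the client sets `auth_expired` rather than attempting a refresh.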
101
surfsense_backend/app/connectors/dropbox/content_extractor.py
Normal file
@@ -0,0 +1,101 @@
"""Content extraction for Dropbox files.

Reuses the same ETL parsing logic as OneDrive/Google Drive since file parsing
is extension-based, not provider-specific.
"""

import contextlib
import logging
import os
import tempfile
from typing import Any

from .client import DropboxClient
from .file_types import get_extension_from_name, is_paper_file, should_skip_file

logger = logging.getLogger(__name__)


async def _export_paper_content(
    client: DropboxClient,
    file: dict[str, Any],
    metadata: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
    """Export a Dropbox Paper doc as markdown via ``/2/files/export``."""
    file_path_lower = file.get("path_lower", "")
    file_name = file.get("name", "Unknown")

    logger.info(f"Exporting Paper doc as markdown: {file_name}")

    content_bytes, error = await client.export_file(
        file_path_lower, export_format="markdown"
    )
    if error:
        return None, metadata, error
    if not content_bytes:
        return None, metadata, "Export returned empty content"

    markdown = content_bytes.decode("utf-8", errors="replace")
    metadata["exported_as"] = "markdown"
    metadata["original_type"] = "paper"
    return markdown, metadata, None


async def download_and_extract_content(
    client: DropboxClient,
    file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
    """Download a Dropbox file and extract its content as markdown.

    Returns (markdown_content, dropbox_metadata, error_message).
    """
    file_path_lower = file.get("path_lower", "")
    file_name = file.get("name", "Unknown")
    file_id = file.get("id", "")

    if should_skip_file(file):
        return None, {}, "Skipping non-indexable item"

    logger.info(f"Downloading file for content extraction: {file_name}")

    metadata: dict[str, Any] = {
        "dropbox_file_id": file_id,
        "dropbox_file_name": file_name,
        "dropbox_path": file_path_lower,
        "source_connector": "dropbox",
    }

    if "server_modified" in file:
        metadata["modified_time"] = file["server_modified"]
    if "client_modified" in file:
        metadata["created_time"] = file["client_modified"]
    if "size" in file:
        metadata["file_size"] = file["size"]
    if "content_hash" in file:
        metadata["content_hash"] = file["content_hash"]

    if is_paper_file(file):
        return await _export_paper_content(client, file, metadata)

    temp_file_path = None
    try:
        extension = get_extension_from_name(file_name) or ".bin"
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
            temp_file_path = tmp.name

        error = await client.download_file_to_disk(file_path_lower, temp_file_path)
        if error:
            return None, metadata, error

        from app.connectors.onedrive.content_extractor import _parse_file_to_markdown

        markdown = await _parse_file_to_markdown(temp_file_path, file_name)
        return markdown, metadata, None

    except Exception as e:
        logger.warning(f"Failed to extract content from {file_name}: {e!s}")
        return None, metadata, str(e)
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            with contextlib.suppress(Exception):
                os.unlink(temp_file_path)
||||
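Review note on `download_and_extract_content`: the temp-file handling follows a reserve, use, always-delete pattern. The sketch below isolates that pattern with standard-library calls only. `process_via_temp_file` is a hypothetical name, and writing bytes locally stands in for the Dropbox download and the ETL parse, which are not reproduced here.

```python
import contextlib
import os
import tempfile


def process_via_temp_file(data: bytes, suffix: str = ".bin") -> tuple[int, str]:
    """Write `data` to a named temp file, measure it, and always clean up.

    Returns (size_in_bytes, path_that_was_used). The path no longer exists
    by the time the caller receives it.
    """
    temp_path = None
    try:
        # delete=False keeps the file after the handle closes, so another
        # component can reopen it by path (here: the second open() call).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_path = tmp.name

        with open(temp_path, "wb") as f:
            f.write(data)

        return os.path.getsize(temp_path), temp_path
    finally:
        # Runs on success and on any exception; errors during cleanup are
        # swallowed so they cannot mask the original result or error.
        if temp_path and os.path.exists(temp_path):
            with contextlib.suppress(Exception):
                os.unlink(temp_path)


size, used_path = process_via_temp_file(b"hello", suffix=".txt")
still_exists = os.path.exists(used_path)
```

Keeping the original extension as the suffix matters in the connector because, per the module docstring, parsing is extension-based.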
58
surfsense_backend/app/connectors/dropbox/file_types.py
Normal file
@@ -0,0 +1,58 @@
"""File type handlers for Dropbox."""

PAPER_EXTENSION = ".paper"

SKIP_EXTENSIONS: frozenset[str] = frozenset()

MIME_TO_EXTENSION: dict[str, str] = {
    "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
    "application/vnd.ms-excel": ".xls",
    "application/msword": ".doc",
    "application/vnd.ms-powerpoint": ".ppt",
    "text/plain": ".txt",
    "text/csv": ".csv",
    "text/html": ".html",
    "text/markdown": ".md",
    "application/json": ".json",
    "application/xml": ".xml",
    "image/png": ".png",
    "image/jpeg": ".jpg",
}


def get_extension_from_name(name: str) -> str:
    """Extract extension from filename."""
    dot = name.rfind(".")
    if dot > 0:
        return name[dot:]
    return ""


def is_folder(item: dict) -> bool:
    return item.get(".tag") == "folder"


def is_paper_file(item: dict) -> bool:
    """Detect Dropbox Paper docs (exported via /2/files/export, not /2/files/download)."""
    name = item.get("name", "")
    ext = get_extension_from_name(name).lower()
    return ext == PAPER_EXTENSION


def should_skip_file(item: dict) -> bool:
    """Skip folders and truly non-indexable files.

    Paper docs are non-downloadable but exportable, so they are NOT skipped.
    """
    if is_folder(item):
        return True
    if is_paper_file(item):
        return False
    if not item.get("is_downloadable", True):
        return True
    name = item.get("name", "")
    ext = get_extension_from_name(name).lower()
    return ext in SKIP_EXTENSIONS
||||
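Review note on `file_types.py`: the order of checks in `should_skip_file` is what lets Paper docs through even though Dropbox reports them as non-downloadable. The sketch below copies the helper logic from the diff in condensed form so that it runs on its own. The sample entries are made up for illustration and are not taken from a real Dropbox response.

```python
PAPER_EXTENSION = ".paper"
SKIP_EXTENSIONS: frozenset = frozenset()


def get_extension_from_name(name: str) -> str:
    """Return the last extension including the dot, or "" if there is none."""
    dot = name.rfind(".")
    # dot > 0 (not >= 0) means a leading-dot name like ".gitignore" has no extension.
    return name[dot:] if dot > 0 else ""


def should_skip_file(item: dict) -> bool:
    """Condensed copy of the decision order used in the diff."""
    if item.get(".tag") == "folder":
        return True
    if get_extension_from_name(item.get("name", "")).lower() == PAPER_EXTENSION:
        return False  # Paper docs are exported, so they stay indexable
    if not item.get("is_downloadable", True):
        return True
    return get_extension_from_name(item.get("name", "")).lower() in SKIP_EXTENSIONS


# Made-up sample entries.
samples = [
    {".tag": "folder", "name": "Reports"},
    {".tag": "file", "name": "Notes.PAPER", "is_downloadable": False},
    {".tag": "file", "name": "diagram.gdoc", "is_downloadable": False},
    {".tag": "file", "name": "Q3.pdf"},
]
decisions = [should_skip_file(s) for s in samples]

dotfile_ext = get_extension_from_name(".gitignore")
double_ext = get_extension_from_name("archive.tar.gz")
case_kept = get_extension_from_name("Notes.PAPER")
```

Note that `get_extension_from_name` preserves case, which is why both `is_paper_file` and `should_skip_file` lowercase the result before comparing.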
92
surfsense_backend/app/connectors/dropbox/folder_manager.py
Normal file
@@ -0,0 +1,92 @@
"""Folder management for Dropbox."""

import logging
from typing import Any

from .client import DropboxClient
from .file_types import is_folder, should_skip_file

logger = logging.getLogger(__name__)


async def list_folder_contents(
    client: DropboxClient,
    path: str = "",
) -> tuple[list[dict[str, Any]], str | None]:
    """List folders and files in a Dropbox folder.

    Returns (items list with folders first, error message).
    """
    try:
        items, error = await client.list_folder(path)
        if error:
            return [], error

        for item in items:
            item["isFolder"] = is_folder(item)

        items.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower()))

        folder_count = sum(1 for item in items if item["isFolder"])
        file_count = len(items) - folder_count
        logger.info(
            f"Listed {len(items)} items ({folder_count} folders, {file_count} files) "
            + (f"in folder {path}" if path else "in root")
        )
        return items, None

    except Exception as e:
        logger.error(f"Error listing folder contents: {e!s}", exc_info=True)
        return [], f"Error listing folder contents: {e!s}"


async def get_files_in_folder(
    client: DropboxClient,
    path: str,
    include_subfolders: bool = True,
) -> tuple[list[dict[str, Any]], str | None]:
    """Get all indexable files in a folder, optionally recursing into subfolders."""
    try:
        items, error = await client.list_folder(path)
        if error:
            return [], error

        files: list[dict[str, Any]] = []
        for item in items:
            if is_folder(item):
                if include_subfolders:
                    sub_files, sub_error = await get_files_in_folder(
                        client, item.get("path_lower", ""), include_subfolders=True
                    )
                    if sub_error:
                        logger.warning(
                            f"Error recursing into folder {item.get('name')}: {sub_error}"
                        )
                        continue
                    files.extend(sub_files)
            elif not should_skip_file(item):
                files.append(item)

        return files, None

    except Exception as e:
        logger.error(f"Error getting files in folder: {e!s}", exc_info=True)
        return [], f"Error getting files in folder: {e!s}"


async def get_file_by_path(
    client: DropboxClient,
    path: str,
) -> tuple[dict[str, Any] | None, str | None]:
    """Get file metadata by path."""
    try:
        item, error = await client.get_metadata(path)
        if error:
            return None, error
        if not item:
            return None, f"File not found: {path}"
        return item, None

    except Exception as e:
        logger.error(f"Error getting file by path: {e!s}", exc_info=True)
        return None, f"Error getting file by path: {e!s}"
||||
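Review note on `list_folder_contents`: the "folders first, then case-insensitive by name" ordering comes from a single tuple sort key. The sketch below runs that exact key on made-up entries. Only the `.tag` and `name` fields are modelled; real Dropbox entries carry more.

```python
# Made-up entries for illustration.
entries = [
    {".tag": "file", "name": "zeta.txt"},
    {".tag": "folder", "name": "beta"},
    {".tag": "file", "name": "Alpha.md"},
    {".tag": "folder", "name": "Archive"},
]

# Same annotation step as in the diff.
for entry in entries:
    entry["isFolder"] = entry.get(".tag") == "folder"

# `not isFolder` is False for folders and True for files; False sorts
# before True, so folders come first. Within each group the lowercase
# name gives a case-insensitive alphabetical order.
entries.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower()))

ordered_names = [entry["name"] for entry in entries]
folder_count = sum(1 for entry in entries if entry["isFolder"])
```

Because `list.sort` is stable, entries whose names differ only by case keep their original relative order.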
@@ -310,7 +310,7 @@ class GoogleGmailConnector:
         Fetch recent messages from Gmail within specified date range.
         Args:
             max_results: Maximum number of messages to fetch (default: 50)
-            start_date: Start date in YYYY-MM-DD format (default: 30 days ago)
+            start_date: Start date in YYYY-MM-DD format (default: 3 days ago)
             end_date: End date in YYYY-MM-DD format (default: today)
         Returns:
             Tuple containing (messages list with details, error message or None)
@@ -334,8 +334,8 @@ class GoogleGmailConnector:
                 start_query = start_dt.strftime("%Y/%m/%d")
                 query_parts.append(f"after:{start_query}")
             else:
-                # Default to 30 days ago
-                cutoff_date = datetime.now() - timedelta(days=30)
+                # Default to 3 days ago
+                cutoff_date = datetime.now() - timedelta(days=3)
                 date_query = cutoff_date.strftime("%Y/%m/%d")
                 query_parts.append(f"after:{date_query}")

|||
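Review note on the Gmail change: the default lookback shrinks from 30 days to 3 days when no `start_date` is passed. The sketch below shows how the `after:` clause is derived in both branches. `build_after_clause`, its `now` parameter, and `default_days` are hypothetical names used only to make the behaviour checkable with a fixed clock; the connector itself calls `datetime.now()` inline.

```python
from datetime import datetime, timedelta
from typing import Optional


def build_after_clause(
    start_date: Optional[str], now: datetime, default_days: int = 3
) -> str:
    """Hypothetical helper: build the Gmail `after:` search term.

    `start_date` uses YYYY-MM-DD; the Gmail query uses YYYY/MM/DD.
    """
    if start_date:
        start_dt = datetime.strptime(start_date, "%Y-%m-%d")
        return f"after:{start_dt.strftime('%Y/%m/%d')}"
    cutoff = now - timedelta(days=default_days)
    return f"after:{cutoff.strftime('%Y/%m/%d')}"


fixed_now = datetime(2026, 3, 10, 9, 30)

explicit = build_after_clause("2026-01-15", fixed_now)
new_default = build_after_clause(None, fixed_now)
old_default = build_after_clause(None, fixed_now, default_days=30)
```

Callers that relied on the implicit 30-day window would need to pass `start_date` explicitly after this change to keep the same coverage.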
13
surfsense_backend/app/connectors/onedrive/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""Microsoft OneDrive Connector Module."""

from .client import OneDriveClient
from .content_extractor import download_and_extract_content
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents

__all__ = [
    "OneDriveClient",
    "download_and_extract_content",
    "get_file_by_id",
    "get_files_in_folder",
    "list_folder_contents",
]
283
surfsense_backend/app/connectors/onedrive/client.py
Normal file
@@ -0,0 +1,283 @@
"""Microsoft OneDrive API client using Microsoft Graph API v1.0."""

import logging
from datetime import UTC, datetime, timedelta
from typing import Any

import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.db import SearchSourceConnector
from app.utils.oauth_security import TokenEncryption

logger = logging.getLogger(__name__)

GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"


class OneDriveClient:
    """Client for Microsoft OneDrive via the Graph API."""

    def __init__(self, session: AsyncSession, connector_id: int):
        self._session = session
        self._connector_id = connector_id

    async def _get_valid_token(self) -> str:
        """Get a valid access token, refreshing if needed."""
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        is_encrypted = cfg.get("_token_encrypted", False)
        token_encryption = (
            TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None
        )

        access_token = cfg.get("access_token", "")
        refresh_token = cfg.get("refresh_token")

        if is_encrypted and token_encryption:
            if access_token:
                access_token = token_encryption.decrypt_token(access_token)
            if refresh_token:
                refresh_token = token_encryption.decrypt_token(refresh_token)

        expires_at_str = cfg.get("expires_at")
        is_expired = False
        if expires_at_str:
            expires_at = datetime.fromisoformat(expires_at_str)
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=UTC)
            is_expired = expires_at <= datetime.now(UTC)

        if not is_expired and access_token:
            return access_token

        if not refresh_token:
            cfg["auth_expired"] = True
            connector.config = cfg
            flag_modified(connector, "config")
            await self._session.commit()
            raise ValueError("OneDrive token expired and no refresh token available")

        token_data = await self._refresh_token(refresh_token)

        new_access = token_data["access_token"]
        new_refresh = token_data.get("refresh_token", refresh_token)
        expires_in = token_data.get("expires_in")

        new_expires_at = None
        if expires_in:
            new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

        if token_encryption:
            cfg["access_token"] = token_encryption.encrypt_token(new_access)
            cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh)
        else:
            cfg["access_token"] = new_access
            cfg["refresh_token"] = new_refresh

        cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
        cfg["expires_in"] = expires_in
        cfg["_token_encrypted"] = bool(token_encryption)
        cfg.pop("auth_expired", None)

        connector.config = cfg
        flag_modified(connector, "config")
        await self._session.commit()

        return new_access

    async def _refresh_token(self, refresh_token: str) -> dict:
        data = {
            "client_id": config.MICROSOFT_CLIENT_ID,
            "client_secret": config.MICROSOFT_CLIENT_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
        }
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                TOKEN_URL,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )
            if resp.status_code != 200:
                error_detail = resp.text
                try:
                    error_json = resp.json()
                    error_detail = error_json.get("error_description", error_detail)
                except Exception:
                    pass
                raise ValueError(f"OneDrive token refresh failed: {error_detail}")
            return resp.json()

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Make an authenticated request to the Graph API."""
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        async with httpx.AsyncClient() as client:
            resp = await client.request(
                method,
                f"{GRAPH_API_BASE}{path}",
                headers=headers,
                timeout=60.0,
                **kwargs,
            )

        if resp.status_code == 401:
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()
            if connector:
                cfg = connector.config or {}
                cfg["auth_expired"] = True
                connector.config = cfg
                flag_modified(connector, "config")
                await self._session.commit()
            raise ValueError("OneDrive authentication expired (401)")

        return resp

    async def list_children(
        self, item_id: str = "root"
    ) -> tuple[list[dict[str, Any]], str | None]:
        all_items: list[dict[str, Any]] = []
        url = f"/me/drive/items/{item_id}/children"
        params: dict[str, Any] = {
            "$top": 200,
            "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl,remoteItem,package",
        }
        while url:
            resp = await self._request("GET", url, params=params)
            if resp.status_code != 200:
                return [], f"Failed to list children: {resp.status_code} - {resp.text}"
            data = resp.json()
            all_items.extend(data.get("value", []))
            next_link = data.get("@odata.nextLink")
            if next_link:
                url = next_link.replace(GRAPH_API_BASE, "")
                params = {}
            else:
                url = ""
        return all_items, None

    async def get_item_metadata(
        self, item_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        resp = await self._request(
            "GET",
            f"/me/drive/items/{item_id}",
            params={
                "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl"
            },
        )
        if resp.status_code != 200:
            return None, f"Failed to get item: {resp.status_code} - {resp.text}"
        return resp.json(), None

    async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
        token = await self._get_valid_token()
        async with httpx.AsyncClient(follow_redirects=True) as client:
            resp = await client.get(
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            )
            if resp.status_code != 200:
                return None, f"Download failed: {resp.status_code}"
            return resp.content, None

    async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
        """Stream file content to disk. Returns error message on failure."""
        token = await self._get_valid_token()
        async with (
            httpx.AsyncClient(follow_redirects=True) as client,
            client.stream(
                "GET",
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            ) as resp,
        ):
            if resp.status_code != 200:
                return f"Download failed: {resp.status_code}"
            with open(dest_path, "wb") as f:
                async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024):
                    f.write(chunk)
        return None

    async def create_file(
        self,
        name: str,
        parent_id: str | None = None,
        content: str | bytes | None = None,
        mime_type: str | None = None,
    ) -> dict[str, Any]:
        """Create (upload) a file in OneDrive."""
        folder_path = f"/me/drive/items/{parent_id or 'root'}"
        if isinstance(content, bytes):
            body = content
        else:
            body = (content or "").encode("utf-8")
        resp = await self._request(
            "PUT",
            f"{folder_path}:/{name}:/content",
            content=body,
            headers={"Content-Type": mime_type or "application/octet-stream"},
        )
        if resp.status_code not in (200, 201):
            raise ValueError(f"File creation failed: {resp.status_code} - {resp.text}")
        return resp.json()

    async def trash_file(self, item_id: str) -> bool:
        """Delete (move to recycle bin) a OneDrive item."""
        resp = await self._request("DELETE", f"/me/drive/items/{item_id}")
        if resp.status_code not in (200, 204):
            raise ValueError(f"Trash failed: {resp.status_code} - {resp.text}")
        return True
||||
|
||||
async def get_delta(
|
||||
self, folder_id: str | None = None, delta_link: str | None = None
|
||||
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
||||
"""Get delta changes. Returns (changes, new_delta_link, error)."""
|
||||
all_changes: list[dict[str, Any]] = []
|
||||
if delta_link:
|
||||
url = delta_link.replace(GRAPH_API_BASE, "")
|
||||
elif folder_id:
|
||||
url = f"/me/drive/items/{folder_id}/delta"
|
||||
else:
|
||||
url = "/me/drive/root/delta"
|
||||
|
||||
params: dict[str, Any] = {"$top": 200}
|
||||
while url:
|
||||
resp = await self._request("GET", url, params=params)
|
||||
if resp.status_code != 200:
|
||||
return [], None, f"Delta failed: {resp.status_code} - {resp.text}"
|
||||
data = resp.json()
|
||||
all_changes.extend(data.get("value", []))
|
||||
next_link = data.get("@odata.nextLink")
|
||||
new_delta_link = data.get("@odata.deltaLink")
|
||||
if next_link:
|
||||
url = next_link.replace(GRAPH_API_BASE, "")
|
||||
params = {}
|
||||
else:
|
||||
url = ""
|
||||
return all_changes, new_delta_link, None
|
||||
181
surfsense_backend/app/connectors/onedrive/content_extractor.py
Normal file
|
|
@@ -0,0 +1,181 @@
|
|||
"""Content extraction for OneDrive files.
|
||||
|
||||
Reuses the same ETL parsing logic as Google Drive since file parsing is
|
||||
extension-based, not provider-specific.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .client import OneDriveClient
|
||||
from .file_types import get_extension_from_mime, should_skip_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_and_extract_content(
|
||||
client: OneDriveClient,
|
||||
file: dict[str, Any],
|
||||
) -> tuple[str | None, dict[str, Any], str | None]:
|
||||
"""Download a OneDrive file and extract its content as markdown.
|
||||
|
||||
Returns (markdown_content, onedrive_metadata, error_message).
|
||||
"""
|
||||
item_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if should_skip_file(file):
|
||||
return None, {}, "Skipping non-indexable item"
|
||||
|
||||
file_info = file.get("file", {})
|
||||
mime_type = file_info.get("mimeType", "")
|
||||
|
||||
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"onedrive_file_id": item_id,
|
||||
"onedrive_file_name": file_name,
|
||||
"onedrive_mime_type": mime_type,
|
||||
"source_connector": "onedrive",
|
||||
}
|
||||
if "lastModifiedDateTime" in file:
|
||||
metadata["modified_time"] = file["lastModifiedDateTime"]
|
||||
if "createdDateTime" in file:
|
||||
metadata["created_time"] = file["createdDateTime"]
|
||||
if "size" in file:
|
||||
metadata["file_size"] = file["size"]
|
||||
if "webUrl" in file:
|
||||
metadata["web_url"] = file["webUrl"]
|
||||
file_hashes = file_info.get("hashes", {})
|
||||
if file_hashes.get("sha256Hash"):
|
||||
metadata["sha256_hash"] = file_hashes["sha256Hash"]
|
||||
elif file_hashes.get("quickXorHash"):
|
||||
metadata["quick_xor_hash"] = file_hashes["quickXorHash"]
|
||||
|
||||
temp_file_path = None
|
||||
try:
|
||||
extension = (
|
||||
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
|
||||
)
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
|
||||
temp_file_path = tmp.name
|
||||
|
||||
error = await client.download_file_to_disk(item_id, temp_file_path)
|
||||
if error:
|
||||
return None, metadata, error
|
||||
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
|
||||
return markdown, metadata, None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract content from {file_name}: {e!s}")
|
||||
return None, metadata, str(e)
|
||||
finally:
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service.
|
||||
|
||||
Same logic as Google Drive -- file parsing is extension-based.
|
||||
"""
|
||||
lower = filename.lower()
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(
|
||||
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path, estimated_pages=50
|
||||
)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(
|
||||
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
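Note on the temp-file handling in `download_and_extract_content` above: the suffix comes from the file name, then from the MIME mapping, then falls back to `.bin`, and the temp file is removed in a `finally` block whether or not parsing succeeds. A minimal sketch of that pattern follows. `pick_extension` and `with_temp_file` are hypothetical helper names, and the two-entry mapping is an illustrative stand-in for `MIME_TO_EXTENSION`.

```python
import contextlib
import os
import tempfile
from pathlib import Path

# Hypothetical, shortened stand-in for MIME_TO_EXTENSION in file_types.py.
MIME_TO_EXTENSION = {"application/pdf": ".pdf", "text/plain": ".txt"}


def pick_extension(file_name: str, mime_type: str) -> str:
    """Prefer the name's suffix, then the MIME mapping, then '.bin'."""
    return Path(file_name).suffix or MIME_TO_EXTENSION.get(mime_type) or ".bin"


def with_temp_file(file_name: str, mime_type: str, payload: bytes) -> tuple[str, int]:
    """Write payload to a temp file, measure it, and always clean up afterwards."""
    temp_path = None
    try:
        suffix = pick_extension(file_name, mime_type)
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_path = tmp.name
            tmp.write(payload)
        return temp_path, os.path.getsize(temp_path)
    finally:
        # Runs on success and on failure, so no temp file is left behind.
        if temp_path and os.path.exists(temp_path):
            with contextlib.suppress(OSError):
                os.unlink(temp_path)


path, size = with_temp_file("report", "application/pdf", b"%PDF-1.7")
```

Keeping the suffix matters because the parsers in `_parse_file_to_markdown` are selected by extension, not by inspecting the file's contents.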
50
surfsense_backend/app/connectors/onedrive/file_types.py
Normal file
|
|
@@ -0,0 +1,50 @@
|
|||
"""File type handlers for Microsoft OneDrive."""
|
||||
|
||||
ONEDRIVE_FOLDER_FACET = "folder"
|
||||
ONENOTE_MIME = "application/msonenote"
|
||||
|
||||
SKIP_MIME_TYPES = frozenset(
|
||||
{
|
||||
ONENOTE_MIME,
|
||||
"application/vnd.ms-onenotesection",
|
||||
"application/vnd.ms-onenotenotebook",
|
||||
}
|
||||
)
|
||||
|
||||
MIME_TO_EXTENSION: dict[str, str] = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
}
|
||||
|
||||
|
||||
def get_extension_from_mime(mime_type: str) -> str | None:
|
||||
return MIME_TO_EXTENSION.get(mime_type)
|
||||
|
||||
|
||||
def is_folder(item: dict) -> bool:
|
||||
return ONEDRIVE_FOLDER_FACET in item
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
"""Skip folders, OneNote files, remote items (shared links), and packages."""
|
||||
if is_folder(item):
|
||||
return True
|
||||
if "remoteItem" in item:
|
||||
return True
|
||||
if "package" in item:
|
||||
return True
|
||||
mime = item.get("file", {}).get("mimeType", "")
|
||||
return mime in SKIP_MIME_TYPES
|
||||
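Note on `should_skip_file` above: it decides by which facets a Graph driveItem carries (`folder`, `remoteItem`, `package`) and only then by MIME type. The sketch below repeats the check in compact form so it runs on its own, and applies it to a few hypothetical driveItem payloads reduced to the fields the check reads. The sample items are illustrative assumptions, not data from the commit.

```python
from typing import Any

# Copy of the skip list from file_types.py, repeated so the sketch is self-contained.
SKIP_MIME_TYPES = frozenset(
    {
        "application/msonenote",
        "application/vnd.ms-onenotesection",
        "application/vnd.ms-onenotenotebook",
    }
)


def should_skip_file(item: dict[str, Any]) -> bool:
    """Compact mirror of the facet checks: folder, remoteItem, package, then MIME type."""
    if "folder" in item or "remoteItem" in item or "package" in item:
        return True
    return item.get("file", {}).get("mimeType", "") in SKIP_MIME_TYPES


XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

# Hypothetical payloads keyed by a display label.
SAMPLE_ITEMS: dict[str, dict[str, Any]] = {
    "Reports": {"folder": {"childCount": 3}},
    "budget.xlsx": {"file": {"mimeType": XLSX}},
    "Shared deck": {"remoteItem": {"id": "abc"}, "file": {"mimeType": XLSX}},
    "Notebook": {"package": {"type": "oneNote"}},
    "notes.one": {"file": {"mimeType": "application/msonenote"}},
}

decisions = {label: should_skip_file(item) for label, item in SAMPLE_ITEMS.items()}
```

Only `budget.xlsx` is kept for indexing here; the shared item is skipped even though it has a parseable MIME type, because the `remoteItem` facet is checked first.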
99
surfsense_backend/app/connectors/onedrive/folder_manager.py
Normal file
|
|
@@ -0,0 +1,99 @@
|
|||
"""Folder management for Microsoft OneDrive."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .client import OneDriveClient
|
||||
from .file_types import is_folder, should_skip_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def list_folder_contents(
|
||||
client: OneDriveClient,
|
||||
parent_id: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""List folders and files in a OneDrive folder.
|
||||
|
||||
Returns (items list with folders first, error message).
|
||||
"""
|
||||
try:
|
||||
items, error = await client.list_children(parent_id or "root")
|
||||
if error:
|
||||
return [], error
|
||||
|
||||
for item in items:
|
||||
item["isFolder"] = is_folder(item)
|
||||
if item["isFolder"]:
|
||||
item.setdefault("mimeType", "application/vnd.ms-folder")
|
||||
else:
|
||||
item.setdefault(
|
||||
"mimeType",
|
||||
item.get("file", {}).get("mimeType", "application/octet-stream"),
|
||||
)
|
||||
|
||||
items.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower()))
|
||||
|
||||
folder_count = sum(1 for item in items if item["isFolder"])
|
||||
file_count = len(items) - folder_count
|
||||
logger.info(
|
||||
f"Listed {len(items)} items ({folder_count} folders, {file_count} files) "
|
||||
+ (f"in folder {parent_id}" if parent_id else "in root")
|
||||
)
|
||||
return items, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing folder contents: {e!s}", exc_info=True)
|
||||
return [], f"Error listing folder contents: {e!s}"
|
||||
|
||||
|
||||
async def get_files_in_folder(
|
||||
client: OneDriveClient,
|
||||
folder_id: str,
|
||||
include_subfolders: bool = True,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Get all indexable files in a folder, optionally recursing into subfolders."""
|
||||
try:
|
||||
items, error = await client.list_children(folder_id)
|
||||
if error:
|
||||
return [], error
|
||||
|
||||
files: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
if is_folder(item):
|
||||
if include_subfolders:
|
||||
sub_files, sub_error = await get_files_in_folder(
|
||||
client, item["id"], include_subfolders=True
|
||||
)
|
||||
if sub_error:
|
||||
logger.warning(
|
||||
f"Error recursing into folder {item.get('name')}: {sub_error}"
|
||||
)
|
||||
continue
|
||||
files.extend(sub_files)
|
||||
elif not should_skip_file(item):
|
||||
files.append(item)
|
||||
|
||||
return files, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting files in folder: {e!s}", exc_info=True)
|
||||
return [], f"Error getting files in folder: {e!s}"
|
||||
|
||||
|
||||
async def get_file_by_id(
|
||||
client: OneDriveClient,
|
||||
file_id: str,
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Get file metadata by ID."""
|
||||
try:
|
||||
item, error = await client.get_item_metadata(file_id)
|
||||
if error:
|
||||
return None, error
|
||||
if not item:
|
||||
return None, f"File not found: {file_id}"
|
||||
return item, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file by ID: {e!s}", exc_info=True)
|
||||
return None, f"Error getting file by ID: {e!s}"
|
||||
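Note on the ordering in `list_folder_contents` above: the sort key is a tuple, so folders come first (`not isFolder` is `False` for folders, and `False` sorts before `True`), and names are compared case-insensitively within each group. A small sketch with hypothetical items:

```python
from typing import Any

# Hypothetical listing, reduced to the fields the sort reads.
items: list[dict[str, Any]] = [
    {"name": "zeta.txt", "file": {}},
    {"name": "Beta", "folder": {}},
    {"name": "alpha.pdf", "file": {}},
    {"name": "archive", "folder": {}},
]

for item in items:
    # Same facet test as is_folder(): a driveItem is a folder if it has a "folder" key.
    item["isFolder"] = "folder" in item

# Folders first, then case-insensitive name order within each group.
items.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower()))

ordered_names = [item["name"] for item in items]
```

Lower-casing the name in the key is what places `archive` ahead of `Beta`; a plain string comparison would put the capitalised name first.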
|
|
@@ -1,3 +1,4 @@
|
|||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
|
|
@@ -40,6 +41,7 @@ class DocumentType(StrEnum):
|
|||
FILE = "FILE"
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
TEAMS_CONNECTOR = "TEAMS_CONNECTOR"
|
||||
ONEDRIVE_FILE = "ONEDRIVE_FILE"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
YOUTUBE_VIDEO = "YOUTUBE_VIDEO"
|
||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
|
|
@@ -58,6 +60,7 @@ class DocumentType(StrEnum):
|
|||
CIRCLEBACK = "CIRCLEBACK"
|
||||
OBSIDIAN_CONNECTOR = "OBSIDIAN_CONNECTOR"
|
||||
NOTE = "NOTE"
|
||||
DROPBOX_FILE = "DROPBOX_FILE"
|
||||
COMPOSIO_GOOGLE_DRIVE_CONNECTOR = "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"
|
||||
COMPOSIO_GMAIL_CONNECTOR = "COMPOSIO_GMAIL_CONNECTOR"
|
||||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
|
@@ -81,6 +84,7 @@ class SearchSourceConnectorType(StrEnum):
|
|||
BAIDU_SEARCH_API = "BAIDU_SEARCH_API" # Baidu AI Search API for Chinese web search
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
TEAMS_CONNECTOR = "TEAMS_CONNECTOR"
|
||||
ONEDRIVE_CONNECTOR = "ONEDRIVE_CONNECTOR"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||
|
|
@@ -101,6 +105,7 @@ class SearchSourceConnectorType(StrEnum):
|
|||
"OBSIDIAN_CONNECTOR" # Self-hosted only - Local Obsidian vault indexing
|
||||
)
|
||||
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
|
||||
DROPBOX_CONNECTOR = "DROPBOX_CONNECTOR"
|
||||
COMPOSIO_GOOGLE_DRIVE_CONNECTOR = "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"
|
||||
COMPOSIO_GMAIL_CONNECTOR = "COMPOSIO_GMAIL_CONNECTOR"
|
||||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
|
@@ -288,6 +293,12 @@ class IncentiveTaskType(StrEnum):
|
|||
# REFER_FRIEND = "REFER_FRIEND"
|
||||
|
||||
|
||||
class PagePurchaseStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# Centralized configuration for incentive tasks
|
||||
# This makes it easy to add new tasks without changing code in multiple places
|
||||
INCENTIVE_TASKS_CONFIG = {
|
||||
|
|
@@ -1639,6 +1650,39 @@ class UserIncentiveTask(BaseModel, TimestampMixin):
|
|||
user = relationship("User", back_populates="incentive_tasks")
|
||||
|
||||
|
||||
class PagePurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional page credits."""
|
||||
|
||||
__tablename__ = "page_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
stripe_checkout_session_id = Column(
|
||||
String(255), nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_payment_intent_id = Column(String(255), nullable=True, index=True)
|
||||
quantity = Column(Integer, nullable=False)
|
||||
pages_granted = Column(Integer, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(PagePurchaseStatus),
|
||||
nullable=False,
|
||||
default=PagePurchaseStatus.PENDING,
|
||||
server_default=text("'PENDING'::pagepurchasestatus"),
|
||||
index=True,
|
||||
)
|
||||
completed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="page_purchases")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Custom roles that can be defined per search space.
|
||||
|
|
@@ -1773,6 +1817,47 @@ class SearchSpaceInvite(BaseModel, TimestampMixin):
|
|||
)
|
||||
|
||||
|
||||
class PromptMode(StrEnum):
|
||||
transform = "transform"
|
||||
explore = "explore"
|
||||
|
||||
|
||||
class Prompt(BaseModel, TimestampMixin):
|
||||
__tablename__ = "prompts"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id",
|
||||
"default_prompt_slug",
|
||||
name="uq_prompt_user_default_slug",
|
||||
),
|
||||
)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
default_prompt_slug = Column(String(100), nullable=True, index=True)
|
||||
name = Column(String(200), nullable=False)
|
||||
prompt = Column(Text, nullable=False)
|
||||
mode = Column(
|
||||
SQLAlchemyEnum(PromptMode, name="prompt_mode", create_type=False),
|
||||
nullable=False,
|
||||
)
|
||||
version = Column(Integer, nullable=False, default=1)
|
||||
is_public = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
user = relationship("User")
|
||||
search_space = relationship("SearchSpace")
|
||||
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
|
|
@@ -1865,6 +1950,11 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
page_purchases = relationship(
|
||||
"PagePurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
@@ -1974,6 +2064,11 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
page_purchases = relationship(
|
||||
"PagePurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
|
|||
74
surfsense_backend/app/prompts/system_defaults.py
Normal file
|
|
@@ -0,0 +1,74 @@
|
|||
SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||
{
|
||||
"slug": "fix-grammar",
|
||||
"version": 1,
|
||||
"name": "Fix grammar",
|
||||
"prompt": (
|
||||
"Fix the grammar and spelling in the following text."
|
||||
" Return only the corrected text, nothing else.\n\n{selection}"
|
||||
),
|
||||
"mode": "transform",
|
||||
},
|
||||
{
|
||||
"slug": "make-shorter",
|
||||
"version": 1,
|
||||
"name": "Make shorter",
|
||||
"prompt": (
|
||||
"Make the following text more concise while preserving its meaning."
|
||||
" Return only the shortened text, nothing else.\n\n{selection}"
|
||||
),
|
||||
"mode": "transform",
|
||||
},
|
||||
{
|
||||
"slug": "translate",
|
||||
"version": 1,
|
||||
"name": "Translate",
|
||||
"prompt": (
|
||||
"Translate the following text to English."
|
||||
" If it is already in English, translate it to French."
|
||||
" Return only the translation, nothing else.\n\n{selection}"
|
||||
),
|
||||
"mode": "transform",
|
||||
},
|
||||
{
|
||||
"slug": "rewrite",
|
||||
"version": 1,
|
||||
"name": "Rewrite",
|
||||
"prompt": (
|
||||
"Rewrite the following text to improve clarity and readability."
|
||||
" Return only the rewritten text, nothing else.\n\n{selection}"
|
||||
),
|
||||
"mode": "transform",
|
||||
},
|
||||
{
|
||||
"slug": "summarize",
|
||||
"version": 1,
|
||||
"name": "Summarize",
|
||||
"prompt": (
|
||||
"Summarize the following text concisely."
|
||||
" Return only the summary, nothing else.\n\n{selection}"
|
||||
),
|
||||
"mode": "transform",
|
||||
},
|
||||
{
|
||||
"slug": "explain",
|
||||
"version": 1,
|
||||
"name": "Explain",
|
||||
"prompt": "Explain the following text in simple terms:\n\n{selection}",
|
||||
"mode": "explore",
|
||||
},
|
||||
{
|
||||
"slug": "ask-knowledge-base",
|
||||
"version": 1,
|
||||
"name": "Ask my knowledge base",
|
||||
"prompt": "Search my knowledge base for information related to:\n\n{selection}",
|
||||
"mode": "explore",
|
||||
},
|
||||
{
|
||||
"slug": "look-up-web",
|
||||
"version": 1,
|
||||
"name": "Look up on the web",
|
||||
"prompt": "Search the web for information about:\n\n{selection}",
|
||||
"mode": "explore",
|
||||
},
|
||||
]
|
||||
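Note on the `{selection}` placeholder in `SYSTEM_PROMPT_DEFAULTS` above: this excerpt does not show how the placeholder is filled in, so the following is only a sketch under that assumption. `render_prompt` and `SAMPLE_DEFAULTS` are hypothetical names. The sketch uses `str.replace` rather than `str.format` so that braces inside the selected text are left untouched.

```python
# Two entries copied from the defaults above, so the sketch is self-contained.
SAMPLE_DEFAULTS: list[dict] = [
    {
        "slug": "summarize",
        "version": 1,
        "name": "Summarize",
        "prompt": (
            "Summarize the following text concisely."
            " Return only the summary, nothing else.\n\n{selection}"
        ),
        "mode": "transform",
    },
    {
        "slug": "explain",
        "version": 1,
        "name": "Explain",
        "prompt": "Explain the following text in simple terms:\n\n{selection}",
        "mode": "explore",
    },
]


def render_prompt(slug: str, selection: str, defaults: list[dict]) -> tuple[str, str]:
    """Hypothetical helper: return (mode, prompt text) for a default prompt slug."""
    by_slug = {entry["slug"]: entry for entry in defaults}
    entry = by_slug[slug]  # raises KeyError for an unknown slug
    # Plain replacement: braces inside the selection are not treated as format fields.
    return entry["mode"], entry["prompt"].replace("{selection}", selection)


mode, text = render_prompt("explain", "a {dict} literal", SAMPLE_DEFAULTS)
```

The returned `mode` distinguishes prompts that rewrite the selection (`transform`) from prompts that start a question about it (`explore`), matching the two `PromptMode` values added in this commit.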
|
|
@@ -10,6 +10,7 @@ from .composio_routes import router as composio_router
|
|||
from .confluence_add_connector_route import router as confluence_add_connector_router
|
||||
from .discord_add_connector_route import router as discord_add_connector_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .dropbox_add_connector_route import router as dropbox_add_connector_router
|
||||
from .editor_routes import router as editor_router
|
||||
from .folders_routes import router as folders_router
|
||||
from .google_calendar_add_connector_route import (
|
||||
|
|
@@ -33,7 +34,9 @@ from .new_llm_config_routes import router as new_llm_config_router
|
|||
from .notes_routes import router as notes_router
|
||||
from .notifications_routes import router as notifications_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .onedrive_add_connector_route import router as onedrive_add_connector_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .prompts_routes import router as prompts_router
|
||||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
from .reports_routes import router as reports_router
|
||||
|
|
@@ -41,6 +44,7 @@ from .sandbox_routes import router as sandbox_router
|
|||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
from .stripe_routes import router as stripe_router
|
||||
from .surfsense_docs_routes import router as surfsense_docs_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
|
|
@@ -73,10 +77,12 @@ router.include_router(luma_add_connector_router)
|
|||
router.include_router(notion_add_connector_router)
|
||||
router.include_router(slack_add_connector_router)
|
||||
router.include_router(teams_add_connector_router)
|
||||
router.include_router(onedrive_add_connector_router)
|
||||
router.include_router(discord_add_connector_router)
|
||||
router.include_router(jira_add_connector_router)
|
||||
router.include_router(confluence_add_connector_router)
|
||||
router.include_router(clickup_add_connector_router)
|
||||
router.include_router(dropbox_add_connector_router)
|
||||
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
|
||||
router.include_router(model_list_router) # Dynamic LLM model catalogue from OpenRouter
|
||||
router.include_router(logs_router)
|
||||
|
|
@@ -86,4 +92,6 @@ router.include_router(notifications_router) # Notifications with Zero sync
|
|||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
|
|
|
|||
567
surfsense_backend/app/routes/dropbox_add_connector_route.py
Normal file
|
|
@@ -0,0 +1,567 @@
|
|||
"""
|
||||
Dropbox Connector OAuth Routes.
|
||||
|
||||
Endpoints:
|
||||
- GET /auth/dropbox/connector/add - Initiate OAuth
|
||||
- GET /auth/dropbox/connector/callback - Handle OAuth callback
|
||||
- GET /auth/dropbox/connector/reauth - Re-authenticate existing connector
|
||||
- GET /connectors/{connector_id}/dropbox/folders - List folder contents
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.dropbox import DropboxClient, list_folder_contents
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
AUTHORIZATION_URL = "https://www.dropbox.com/oauth2/authorize"
|
||||
TOKEN_URL = "https://api.dropboxapi.com/oauth2/token"
|
||||
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
@router.get("/auth/dropbox/connector/add")
|
||||
async def connect_dropbox(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""Initiate Dropbox OAuth flow."""
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
if not config.DROPBOX_APP_KEY:
|
||||
raise HTTPException(status_code=500, detail="Dropbox OAuth not configured.")
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.DROPBOX_APP_KEY,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.DROPBOX_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"token_access_type": "offline",
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated Dropbox OAuth URL for user %s, space %s", user.id, space_id
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate Dropbox OAuth: %s", str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Dropbox OAuth: {e!s}"
|
||||
) from e
|
||||
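Note on the authorization URL built in `connect_dropbox` above: the query string carries the app key, the redirect URI, the signed `state`, and `token_access_type=offline`, which asks Dropbox for a refresh token. The re-auth route in this file adds `force_reapprove=true`. A minimal sketch of that construction follows. `build_auth_url` is a hypothetical helper, and the app key, redirect URI, and state values are placeholders, not values from the commit.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

AUTHORIZATION_URL = "https://www.dropbox.com/oauth2/authorize"


def build_auth_url(
    app_key: str, redirect_uri: str, state: str, reauth: bool = False
) -> str:
    """Hypothetical helper mirroring the parameters used by the two routes."""
    params = {
        "client_id": app_key,
        "response_type": "code",
        "redirect_uri": redirect_uri,
        "state": state,
        "token_access_type": "offline",  # request a refresh token
    }
    if reauth:
        params["force_reapprove"] = "true"  # always show the consent screen again
    return f"{AUTHORIZATION_URL}?{urlencode(params)}"


# Placeholder values for illustration only.
REDIRECT_URI = "http://localhost:8000/auth/dropbox/connector/callback"
url = build_auth_url("demo-app-key", REDIRECT_URI, "opaque-signed-state", reauth=True)
query = parse_qs(urlsplit(url).query)
```

`urlencode` percent-encodes the redirect URI, so the value round-trips unchanged through `parse_qs` on the receiving side.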
|
||||
|
||||
@router.get("/auth/dropbox/connector/reauth")
|
||||
async def reauth_dropbox(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Re-authenticate an existing Dropbox connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Dropbox connector not found or access denied"
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.DROPBOX_APP_KEY,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.DROPBOX_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"token_access_type": "offline",
|
||||
"force_reapprove": "true",
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating Dropbox re-auth for user %s, connector %s",
|
||||
user.id,
|
||||
connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate Dropbox re-auth: %s", str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Dropbox re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/dropbox/connector/callback")
|
||||
async def dropbox_callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
error_description: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Handle Dropbox OAuth callback."""
|
||||
try:
|
||||
if error:
|
||||
error_msg = error_description or error
|
||||
logger.warning("Dropbox OAuth error: %s", error_msg)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=dropbox_oauth_denied"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=dropbox_oauth_denied"
|
||||
)
|
||||
|
||||
if not code or not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required OAuth parameters"
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data["space_id"]
|
||||
user_id = UUID(data["user_id"])
|
||||
except (HTTPException, ValueError, KeyError) as e:
|
||||
logger.error("Invalid OAuth state: %s", str(e))
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state"
|
||||
)
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
token_data = {
|
||||
"client_id": config.DROPBOX_APP_KEY,
|
||||
"client_secret": config.DROPBOX_APP_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": config.DROPBOX_REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
TOKEN_URL,
|
||||
data=token_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token exchange failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = token_response.json()
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Dropbox"
|
||||
)
|
||||
|
||||
token_encryption = get_token_encryption()
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
user_info: dict = {}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_response = await client.post(
|
||||
"https://api.dropboxapi.com/2/users/get_current_account",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
content=b"null",
|
||||
timeout=30.0,
|
||||
)
|
||||
if user_response.status_code == 200:
|
||||
user_data = user_response.json()
|
||||
user_info = {
|
||||
"user_email": user_data.get("email"),
|
||||
"user_name": user_data.get("name", {}).get("display_name"),
|
||||
"account_id": user_data.get("account_id"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch user info from Dropbox: %s", str(e))
|
||||
|
||||
connector_config = {
|
||||
"access_token": token_encryption.encrypt_token(access_token),
|
||||
"refresh_token": token_encryption.encrypt_token(refresh_token)
|
||||
if refresh_token
|
||||
else None,
|
||||
"token_type": token_json.get("token_type", "bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"user_email": user_info.get("user_email"),
|
||||
"user_name": user_info.get("user_name"),
|
||||
"account_id": user_info.get("account_id"),
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
existing_cursor = db_connector.config.get("cursor")
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"cursor": existing_cursor,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated Dropbox connector %s for user %s",
|
||||
db_connector.id,
|
||||
user_id,
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=DROPBOX_CONNECTOR&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.DROPBOX_CONNECTOR, connector_config
|
||||
)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
space_id,
|
||||
user_id,
|
||||
connector_identifier,
|
||||
)
|
||||
if is_duplicate:
|
||||
logger.warning(
|
||||
"Duplicate Dropbox connector for user %s, space %s", user_id, space_id
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=DROPBOX_CONNECTOR"
|
||||
)
|
||||
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
space_id,
|
||||
user_id,
|
||||
connector_identifier,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
is_indexable=True,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
try:
|
||||
session.add(new_connector)
|
||||
await session.commit()
|
||||
await session.refresh(new_connector)
|
||||
logger.info(
|
||||
"Successfully created Dropbox connector %s for user %s",
|
||||
new_connector.id,
|
||||
user_id,
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=DROPBOX_CONNECTOR&connectorId={new_connector.id}"
|
||||
)
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
logger.error(
|
||||
"Database integrity error creating Dropbox connector: %s", str(e)
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except (IntegrityError, ValueError) as e:
|
||||
logger.error("Dropbox OAuth callback error: %s", str(e), exc_info=True)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=dropbox_auth_error"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/connectors/{connector_id}/dropbox/folders")
|
||||
async def list_dropbox_folders(
|
||||
connector_id: int,
|
||||
parent_path: str = "",
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List folders and files in user's Dropbox."""
|
||||
connector = None
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DROPBOX_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Dropbox connector not found or access denied"
|
||||
)
|
||||
|
||||
dropbox_client = DropboxClient(session, connector_id)
|
||||
items, error = await list_folder_contents(dropbox_client, path=parent_path)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "authentication expired" in error_lower
|
||||
or "expired_access_token" in error_lower
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Dropbox authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
return {"items": items}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error listing Dropbox contents: %s", str(e), exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if "401" in str(e) or "authentication expired" in error_lower:
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Dropbox authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list Dropbox contents: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_dropbox_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
"""Refresh Dropbox OAuth tokens."""
|
||||
logger.info("Refreshing Dropbox OAuth tokens for connector %s", connector.id)
|
||||
|
||||
token_encryption = get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
refresh_token = connector.config.get("refresh_token")
|
||||
|
||||
if is_encrypted and refresh_token:
|
||||
try:
|
||||
refresh_token = token_encryption.decrypt_token(refresh_token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decrypt refresh token: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No refresh token available for connector {connector.id}",
|
||||
)
|
||||
|
||||
refresh_data = {
|
||||
"client_id": config.DROPBOX_APP_KEY,
|
||||
"client_secret": config.DROPBOX_APP_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
TOKEN_URL,
|
||||
data=refresh_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Dropbox authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = token_response.json()
|
||||
access_token = token_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Dropbox refresh"
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
expires_in = token_json.get("expires_in")
|
||||
if expires_in:
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))
|
||||
|
||||
cfg = dict(connector.config)
|
||||
cfg["access_token"] = token_encryption.encrypt_token(access_token)
|
||||
cfg["expires_in"] = expires_in
|
||||
cfg["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||
cfg["_token_encrypted"] = True
|
||||
cfg.pop("auth_expired", None)
|
||||
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Successfully refreshed Dropbox tokens for connector %s", connector.id)
|
||||
return connector
|
||||
|
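The Dropbox route file above derives `expires_at` from the token response's `expires_in` and maps certain refresh failures to a 401 that asks the user to re-authenticate. A minimal sketch of those two steps in isolation follows, assuming nothing beyond the standard library. The helper names `compute_expires_at` and `needs_reauth` are hypothetical and do not appear in the commit, and `timezone.utc` stands in for the `datetime.UTC` alias the route imports.

```python
from datetime import datetime, timedelta, timezone


def compute_expires_at(expires_in, now=None):
    """Return an ISO-8601 expiry timestamp, or None when expires_in is missing.

    Hypothetical helper: it mirrors the expires_at arithmetic in the route,
    with an injectable clock so the result can be checked deterministically.
    """
    if not expires_in:
        return None
    now = now or datetime.now(timezone.utc)
    return (now + timedelta(seconds=int(expires_in))).isoformat()


def needs_reauth(error_code, error_detail):
    """Return True when a refresh failure looks like a dead refresh token.

    Hypothetical helper: it reproduces the substring check the route applies
    to the concatenated error description and error code.
    """
    haystack = (error_detail + error_code).lower()
    return any(marker in haystack for marker in ("invalid_grant", "expired", "revoked"))


if __name__ == "__main__":
    fixed_clock = datetime(2026, 1, 1, tzinfo=timezone.utc)
    print(compute_expires_at(14400, now=fixed_clock))
    print(needs_reauth("invalid_grant", "refresh token is invalid"))
```

Because the check is a plain substring match, any error description that happens to contain "expired" is treated as a re-authentication case, which is worth keeping in mind when reading the 401 branch.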
|
@@ -117,8 +117,10 @@ async def complete_task(
|
|||
)
|
||||
session.add(new_task)
|
||||
|
||||
- # Update user's pages_limit
|
||||
- user.pages_limit += pages_reward
|
||||
+ # pages_used can exceed pages_limit when a document's final page count is
|
||||
+ # determined after processing. Base the new limit on the higher of the two
|
||||
+ # so the rewarded pages are fully usable above the current high-water mark.
|
||||
+ user.pages_limit = max(user.pages_used, user.pages_limit) + pages_reward
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
|
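The `complete_task` hunk above swaps a plain increment for a `max`-based update. A small sketch of the arithmetic, with hypothetical function names and made-up page counts, shows why the two differ once `pages_used` has overshot `pages_limit`:

```python
def limit_with_increment(pages_used, pages_limit, pages_reward):
    """Previous behaviour: add the reward to the current limit only."""
    return pages_limit + pages_reward


def limit_with_high_water_mark(pages_used, pages_limit, pages_reward):
    """New behaviour: add the reward to the higher of usage and limit."""
    return max(pages_used, pages_limit) + pages_reward


if __name__ == "__main__":
    # Illustrative numbers only: usage has overshot the limit by 20 pages.
    used, limit, reward = 520, 500, 100
    for fn in (limit_with_increment, limit_with_high_water_mark):
        new_limit = fn(used, limit, reward)
        print(fn.__name__, new_limit, "usable:", new_limit - used)
```

With these numbers the increment leaves only 80 of the 100 rewarded pages usable, while the `max`-based update leaves all 100. When usage is at or below the limit, both versions give the same result.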
|
|||
581
surfsense_backend/app/routes/onedrive_add_connector_route.py
Normal file
|
|
@@ -0,0 +1,581 @@
|
|||
"""
|
||||
Microsoft OneDrive Connector OAuth Routes.
|
||||
|
||||
Endpoints:
|
||||
- GET /auth/onedrive/connector/add - Initiate OAuth
|
||||
- GET /auth/onedrive/connector/callback - Handle OAuth callback
|
||||
- GET /auth/onedrive/connector/reauth - Re-authenticate existing connector
|
||||
- GET /connectors/{connector_id}/onedrive/folders - List folder contents
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.onedrive import OneDriveClient, list_folder_contents
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
AUTHORIZATION_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
SCOPES = [
|
||||
"offline_access",
|
||||
"User.Read",
|
||||
"Files.Read.All",
|
||||
"Files.ReadWrite.All",
|
||||
]
|
||||
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/add")
|
||||
async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""Initiate OneDrive OAuth flow."""
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
if not config.MICROSOFT_CLIENT_ID:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Microsoft OneDrive OAuth not configured."
|
||||
)
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.MICROSOFT_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.ONEDRIVE_REDIRECT_URI,
|
||||
"response_mode": "query",
|
||||
"scope": " ".join(SCOPES),
|
||||
"state": state_encoded,
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated OneDrive OAuth URL for user %s, space %s", user.id, space_id
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate OneDrive OAuth: %s", str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate OneDrive OAuth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/reauth")
|
||||
async def reauth_onedrive(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Re-authenticate an existing OneDrive connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="OneDrive connector not found or access denied"
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.MICROSOFT_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.ONEDRIVE_REDIRECT_URI,
|
||||
"response_mode": "query",
|
||||
"scope": " ".join(SCOPES),
|
||||
"state": state_encoded,
|
||||
"prompt": "consent",
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating OneDrive re-auth for user %s, connector %s",
|
||||
user.id,
|
||||
connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate OneDrive re-auth: %s", str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate OneDrive re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/callback")
|
||||
async def onedrive_callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
error_description: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Handle OneDrive OAuth callback."""
|
||||
try:
|
||||
if error:
|
||||
error_msg = error_description or error
|
||||
logger.warning("OneDrive OAuth error: %s", error_msg)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=onedrive_oauth_denied"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_oauth_denied"
|
||||
)
|
||||
|
||||
if not code or not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required OAuth parameters"
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data["space_id"]
|
||||
user_id = UUID(data["user_id"])
|
||||
except (HTTPException, ValueError, KeyError) as e:
|
||||
logger.error("Invalid OAuth state: %s", str(e))
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state"
|
||||
)
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
token_data = {
|
||||
"client_id": config.MICROSOFT_CLIENT_ID,
|
||||
"client_secret": config.MICROSOFT_CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": config.ONEDRIVE_REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
TOKEN_URL,
|
||||
data=token_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token exchange failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = token_response.json()
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Microsoft"
|
||||
)
|
||||
|
||||
token_encryption = get_token_encryption()
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
user_info: dict = {}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_response = await client.get(
|
||||
"https://graph.microsoft.com/v1.0/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
timeout=30.0,
|
||||
)
|
||||
if user_response.status_code == 200:
|
||||
user_data = user_response.json()
|
||||
user_info = {
|
||||
"user_email": user_data.get("mail")
|
||||
or user_data.get("userPrincipalName"),
|
||||
"user_name": user_data.get("displayName"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch user info from Graph: %s", str(e))
|
||||
|
||||
connector_config = {
|
||||
"access_token": token_encryption.encrypt_token(access_token),
|
||||
"refresh_token": token_encryption.encrypt_token(refresh_token)
|
||||
if refresh_token
|
||||
else None,
|
||||
"token_type": token_json.get("token_type", "Bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": token_json.get("scope"),
|
||||
"user_email": user_info.get("user_email"),
|
||||
"user_name": user_info.get("user_name"),
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Handle re-authentication
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
existing_delta_link = db_connector.config.get("delta_link")
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"delta_link": existing_delta_link,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated OneDrive connector %s for user %s",
|
||||
db_connector.id,
|
||||
user_id,
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# New connector -- check for duplicates
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.ONEDRIVE_CONNECTOR, connector_config
|
||||
)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
space_id,
|
||||
user_id,
|
||||
connector_identifier,
|
||||
)
|
||||
if is_duplicate:
|
||||
logger.warning(
|
||||
"Duplicate OneDrive connector for user %s, space %s", user_id, space_id
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=ONEDRIVE_CONNECTOR"
|
||||
)
|
||||
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
space_id,
|
||||
user_id,
|
||||
connector_identifier,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
is_indexable=True,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
try:
|
||||
session.add(new_connector)
|
||||
await session.commit()
|
||||
await session.refresh(new_connector)
|
||||
logger.info(
|
||||
"Successfully created OneDrive connector %s for user %s",
|
||||
new_connector.id,
|
||||
user_id,
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={new_connector.id}"
|
||||
)
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
logger.error(
|
||||
"Database integrity error creating OneDrive connector: %s", str(e)
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except (IntegrityError, ValueError) as e:
|
||||
logger.error("OneDrive OAuth callback error: %s", str(e), exc_info=True)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_auth_error"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/connectors/{connector_id}/onedrive/folders")
|
||||
async def list_onedrive_folders(
|
||||
connector_id: int,
|
||||
parent_id: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List folders and files in user's OneDrive."""
|
||||
connector = None
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="OneDrive connector not found or access denied"
|
||||
)
|
||||
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
items, error = await list_folder_contents(onedrive_client, parent_id=parent_id)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "authentication expired" in error_lower
|
||||
or "invalid_grant" in error_lower
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="OneDrive authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
return {"items": items}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error listing OneDrive contents: %s", str(e), exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if "401" in str(e) or "authentication expired" in error_lower:
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="OneDrive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list OneDrive contents: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_onedrive_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
"""Refresh OneDrive OAuth tokens."""
|
||||
logger.info("Refreshing OneDrive OAuth tokens for connector %s", connector.id)
|
||||
|
||||
token_encryption = get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
refresh_token = connector.config.get("refresh_token")
|
||||
|
||||
if is_encrypted and refresh_token:
|
||||
try:
|
||||
refresh_token = token_encryption.decrypt_token(refresh_token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decrypt refresh token: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No refresh token available for connector {connector.id}",
|
||||
)
|
||||
|
||||
refresh_data = {
|
||||
"client_id": config.MICROSOFT_CLIENT_ID,
|
||||
"client_secret": config.MICROSOFT_CLIENT_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"scope": " ".join(SCOPES),
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
TOKEN_URL,
|
||||
data=refresh_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="OneDrive authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = token_response.json()
|
||||
access_token = token_json.get("access_token")
|
||||
new_refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Microsoft refresh"
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
expires_in = token_json.get("expires_in")
|
||||
if expires_in:
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))
|
||||
|
||||
cfg = dict(connector.config)
|
||||
cfg["access_token"] = token_encryption.encrypt_token(access_token)
|
||||
if new_refresh_token:
|
||||
cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh_token)
|
||||
cfg["expires_in"] = expires_in
|
||||
cfg["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||
cfg["scope"] = token_json.get("scope")
|
||||
cfg["_token_encrypted"] = True
|
||||
cfg.pop("auth_expired", None)
|
||||
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Successfully refreshed OneDrive tokens for connector %s", connector.id)
|
||||
return connector
|
||||
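The failure handling and expiry bookkeeping in `refresh_onedrive_token` can be sketched as two pure functions. This is a minimal illustration, not the project's code: the names `classify_refresh_failure` and `compute_expires_at` are hypothetical, and the sketch returns a status code instead of raising `HTTPException` so it runs without FastAPI.

```python
from datetime import datetime, timedelta, timezone


def classify_refresh_failure(error_detail: str, error_code: str = "") -> int:
    """Return the HTTP status the handler above would raise for a failed refresh.

    Hypothetical helper: mirrors the substring checks on the token endpoint's
    `error_description` and `error` fields.
    """
    error_lower = (error_detail + error_code).lower()
    if any(marker in error_lower for marker in ("invalid_grant", "expired", "revoked")):
        return 401  # the user has to re-authenticate
    return 400  # any other refresh failure


def compute_expires_at(expires_in, now=None):
    """Turn an `expires_in` value (seconds) into an ISO-8601 timestamp, or None."""
    if not expires_in:
        return None
    now = now or datetime.now(timezone.utc)
    return (now + timedelta(seconds=int(expires_in))).isoformat()
```

Keeping the classification separate from the HTTP call would make the 401 versus 400 decision easy to unit-test without mocking the token endpoint.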
166
surfsense_backend/app/routes/prompts_routes.py
Normal file
|
|
@@ -0,0 +1,166 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import Prompt, SearchSpaceMembership, User, get_async_session
|
||||
from app.schemas.prompts import (
|
||||
PromptCreate,
|
||||
PromptRead,
|
||||
PromptUpdate,
|
||||
PublicPromptRead,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
router = APIRouter(tags=["Prompts"])
|
||||
|
||||
|
||||
@router.get("/prompts", response_model=list[PromptRead])
|
||||
async def list_prompts(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
query = select(Prompt).where(Prompt.user_id == user.id)
|
||||
if search_space_id is not None:
|
||||
query = query.where(Prompt.search_space_id == search_space_id)
|
||||
query = query.order_by(Prompt.created_at.desc())
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/prompts", response_model=PromptRead)
|
||||
async def create_prompt(
|
||||
body: PromptCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if body.search_space_id is not None:
|
||||
membership = await session.execute(
|
||||
select(SearchSpaceMembership).where(
|
||||
SearchSpaceMembership.user_id == user.id,
|
||||
SearchSpaceMembership.search_space_id == body.search_space_id,
|
||||
)
|
||||
)
|
||||
if not membership.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You are not a member of this search space",
|
||||
)
|
||||
|
||||
prompt = Prompt(
|
||||
user_id=user.id,
|
||||
search_space_id=body.search_space_id,
|
||||
name=body.name,
|
||||
prompt=body.prompt,
|
||||
mode=body.mode,
|
||||
is_public=body.is_public,
|
||||
)
|
||||
session.add(prompt)
|
||||
await session.commit()
|
||||
await session.refresh(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
|
||||
async def update_prompt(
|
||||
prompt_id: int,
|
||||
body: PromptUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
Prompt.user_id == user.id,
|
||||
)
|
||||
)
|
||||
prompt = result.scalar_one_or_none()
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
content_fields = {"name", "prompt", "mode"}
|
||||
has_content_change = bool(updates.keys() & content_fields)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(prompt, field, value)
|
||||
|
||||
if has_content_change:
|
||||
prompt.version = Prompt.version + 1
|
||||
|
||||
session.add(prompt)
|
||||
await session.commit()
|
||||
await session.refresh(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
@router.delete("/prompts/{prompt_id}")
|
||||
async def delete_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
Prompt.user_id == user.id,
|
||||
)
|
||||
)
|
||||
prompt = result.scalar_one_or_none()
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
await session.delete(prompt)
|
||||
await session.commit()
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.get("/prompts/public", response_model=list[PublicPromptRead])
|
||||
async def list_public_prompts(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Prompt)
|
||||
.options(selectinload(Prompt.user))
|
||||
.where(Prompt.is_public.is_(True), Prompt.user_id != user.id)
|
||||
.order_by(Prompt.created_at.desc())
|
||||
)
|
||||
prompts = result.scalars().all()
|
||||
return [
|
||||
PublicPromptRead(
|
||||
**PromptRead.model_validate(p).model_dump(),
|
||||
author_name=p.user.email if p.user else None,
|
||||
)
|
||||
for p in prompts
|
||||
]
|
||||
|
||||
|
||||
@router.post("/prompts/{prompt_id}/copy", response_model=PromptRead)
|
||||
async def copy_public_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
Prompt.is_public.is_(True),
|
||||
)
|
||||
)
|
||||
source = result.scalar_one_or_none()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
copy = Prompt(
|
||||
user_id=user.id,
|
||||
name=source.name,
|
||||
prompt=source.prompt,
|
||||
mode=source.mode,
|
||||
is_public=False,
|
||||
)
|
||||
session.add(copy)
|
||||
await session.commit()
|
||||
await session.refresh(copy)
|
||||
return copy
|
||||
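The version rule in `update_prompt` (bump `version` only when `name`, `prompt`, or `mode` is part of the update) can be sketched with plain dictionaries. This is an illustrative stand-in under stated assumptions: `apply_prompt_update` is a hypothetical helper, and a dict replaces the SQLAlchemy `Prompt` row.

```python
# Fields whose modification counts as a content change, as in update_prompt.
CONTENT_FIELDS = {"name", "prompt", "mode"}


def apply_prompt_update(prompt: dict, updates: dict) -> dict:
    """Return a copy of `prompt` with `updates` applied.

    The version is incremented only if at least one content field is present
    in `updates`; toggling `is_public` alone leaves it unchanged.
    """
    updated = {**prompt, **updates}
    if updates.keys() & CONTENT_FIELDS:
        updated["version"] = prompt.get("version", 1) + 1
    return updated
```

In the route itself, `Prompt.version + 1` is a SQL expression, so the increment is evaluated by the database at flush time rather than in Python.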
|
|
@@ -999,6 +999,100 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "Google Drive indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_onedrive_files_task,
|
||||
)
|
||||
|
||||
if drive_items and drive_items.has_items():
|
||||
logger.info(
|
||||
f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id}, "
|
||||
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
|
||||
)
|
||||
items_dict = drive_items.model_dump()
|
||||
else:
|
||||
config = connector.config or {}
|
||||
selected_folders = config.get("selected_folders", [])
|
||||
selected_files = config.get("selected_files", [])
|
||||
if not selected_folders and not selected_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="OneDrive indexing requires folders or files to be configured. "
|
||||
"Please select folders/files to index.",
|
||||
)
|
||||
indexing_options = config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
items_dict = {
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
}
|
||||
logger.info(
|
||||
f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id} "
|
||||
f"using existing config"
|
||||
)
|
||||
|
||||
index_onedrive_files_task.delay(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
items_dict,
|
||||
)
|
||||
response_message = "OneDrive indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DROPBOX_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_dropbox_files_task,
|
||||
)
|
||||
|
||||
if drive_items and drive_items.has_items():
|
||||
logger.info(
|
||||
f"Triggering Dropbox indexing for connector {connector_id} into search space {search_space_id}, "
|
||||
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
|
||||
)
|
||||
items_dict = drive_items.model_dump()
|
||||
else:
|
||||
config = connector.config or {}
|
||||
selected_folders = config.get("selected_folders", [])
|
||||
selected_files = config.get("selected_files", [])
|
||||
if not selected_folders and not selected_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Dropbox indexing requires folders or files to be configured. "
|
||||
"Please select folders/files to index.",
|
||||
)
|
||||
indexing_options = config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
items_dict = {
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
}
|
||||
logger.info(
|
||||
f"Triggering Dropbox indexing for connector {connector_id} into search space {search_space_id} "
|
||||
f"using existing config"
|
||||
)
|
||||
|
||||
index_dropbox_files_task.delay(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
items_dict,
|
||||
)
|
||||
response_message = "Dropbox indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_discord_messages_task,
|
||||
|
|
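The OneDrive and Dropbox branches above share the same fallback: when the request carries no `drive_items`, the payload for the Celery task is rebuilt from the connector's stored config. A minimal sketch of that fallback follows; `build_items_dict` is a hypothetical name, and it raises `ValueError` where the route raises an HTTP 400.

```python
DEFAULT_INDEXING_OPTIONS = {
    "max_files_per_folder": 100,
    "incremental_sync": True,
    "include_subfolders": True,
}


def build_items_dict(connector_config):
    """Build the indexing payload from a connector's stored config.

    Raises ValueError when neither folders nor files were selected.
    """
    config = connector_config or {}
    selected_folders = config.get("selected_folders", [])
    selected_files = config.get("selected_files", [])
    if not selected_folders and not selected_files:
        raise ValueError("indexing requires folders or files to be configured")
    return {
        "folders": selected_folders,
        "files": selected_files,
        # Copy the defaults so callers cannot mutate the shared constant.
        "indexing_options": config.get("indexing_options", dict(DEFAULT_INDEXING_OPTIONS)),
    }
```

Because the two branches differ only in the task they enqueue and the connector name in their messages, a shared helper along these lines could remove the duplication.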
@@ -2489,6 +2583,222 @@ async def run_google_drive_indexing(
|
|||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
async def run_onedrive_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Runs the OneDrive indexing task for folders and files with notifications."""
|
||||
from uuid import UUID
|
||||
|
||||
notification = None
|
||||
try:
|
||||
from app.tasks.connector_indexers.onedrive_indexer import index_onedrive_files
|
||||
|
||||
connector_result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = connector_result.scalar_one_or_none()
|
||||
|
||||
if connector:
|
||||
notification = await NotificationService.connector_indexing.notify_google_drive_indexing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
connector_id=connector_id,
|
||||
connector_name=connector.name,
|
||||
connector_type=connector.connector_type.value,
|
||||
search_space_id=search_space_id,
|
||||
folder_count=len(items_dict.get("folders", [])),
|
||||
file_count=len(items_dict.get("files", [])),
|
||||
folder_names=[
|
||||
f.get("name", "Unknown") for f in items_dict.get("folders", [])
|
||||
],
|
||||
file_names=[
|
||||
f.get("name", "Unknown") for f in items_dict.get("files", [])
|
||||
],
|
||||
)
|
||||
|
||||
if notification:
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=0,
|
||||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_onedrive_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
|
||||
if error_message:
|
||||
logger.error(
|
||||
f"OneDrive indexing completed with errors for connector {connector_id}: {error_message}"
|
||||
)
|
||||
if _is_auth_error(error_message):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
error_message = (
|
||||
"OneDrive authentication expired. Please re-authenticate."
|
||||
)
|
||||
else:
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=total_indexed,
|
||||
stage="storing",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OneDrive indexing successful for connector {connector_id}. Indexed {total_indexed} documents."
|
||||
)
|
||||
await _update_connector_timestamp_by_id(session, connector_id)
|
||||
await session.commit()
|
||||
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Critical error in run_onedrive_indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=0,
|
||||
error_message=str(e),
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
async def run_dropbox_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Runs the Dropbox indexing task for folders and files with notifications."""
|
||||
from uuid import UUID
|
||||
|
||||
notification = None
|
||||
try:
|
||||
from app.tasks.connector_indexers.dropbox_indexer import index_dropbox_files
|
||||
|
||||
connector_result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = connector_result.scalar_one_or_none()
|
||||
|
||||
if connector:
|
||||
notification = await NotificationService.connector_indexing.notify_google_drive_indexing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
connector_id=connector_id,
|
||||
connector_name=connector.name,
|
||||
connector_type=connector.connector_type.value,
|
||||
search_space_id=search_space_id,
|
||||
folder_count=len(items_dict.get("folders", [])),
|
||||
file_count=len(items_dict.get("files", [])),
|
||||
folder_names=[
|
||||
f.get("name", "Unknown") for f in items_dict.get("folders", [])
|
||||
],
|
||||
file_names=[
|
||||
f.get("name", "Unknown") for f in items_dict.get("files", [])
|
||||
],
|
||||
)
|
||||
|
||||
if notification:
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=0,
|
||||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_dropbox_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
|
||||
if error_message:
|
||||
logger.error(
|
||||
f"Dropbox indexing completed with errors for connector {connector_id}: {error_message}"
|
||||
)
|
||||
if _is_auth_error(error_message):
|
||||
await _persist_auth_expired(session, connector_id)
|
||||
error_message = (
|
||||
"Dropbox authentication expired. Please re-authenticate."
|
||||
)
|
||||
else:
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=total_indexed,
|
||||
stage="storing",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dropbox indexing successful for connector {connector_id}. Indexed {total_indexed} documents."
|
||||
)
|
||||
await _update_connector_timestamp_by_id(session, connector_id)
|
||||
await session.commit()
|
||||
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Critical error in run_dropbox_indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=0,
|
||||
error_message=str(e),
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
# Add new helper functions for luma indexing
|
||||
async def run_luma_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
|
|||
371
surfsense_backend/app/routes/stripe_routes.py
Normal file
|
|
@@ -0,0 +1,371 @@
|
|||
"""Stripe routes for pay-as-you-go page purchases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from stripe import SignatureVerificationError, StripeClient, StripeError
|
||||
|
||||
from app.config import config
|
||||
from app.db import PagePurchase, PagePurchaseStatus, User, get_async_session
|
||||
from app.schemas.stripe import (
|
||||
CreateCheckoutSessionRequest,
|
||||
CreateCheckoutSessionResponse,
|
||||
PagePurchaseHistoryResponse,
|
||||
StripeStatusResponse,
|
||||
StripeWebhookResponse,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/stripe", tags=["stripe"])
|
||||
|
||||
|
||||
def get_stripe_client() -> StripeClient:
|
||||
"""Return a configured Stripe client or raise if Stripe is disabled."""
|
||||
if not config.STRIPE_SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Stripe checkout is not configured.",
|
||||
)
|
||||
return StripeClient(config.STRIPE_SECRET_KEY)
|
||||
|
||||
|
||||
def _ensure_page_buying_enabled() -> None:
|
||||
if not config.STRIPE_PAGE_BUYING_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Page purchases are temporarily unavailable.",
|
||||
)
|
||||
|
||||
|
||||
def _get_checkout_urls(search_space_id: int) -> tuple[str, str]:
|
||||
if not config.NEXT_FRONTEND_URL:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="NEXT_FRONTEND_URL is not configured.",
|
||||
)
|
||||
|
||||
base_url = config.NEXT_FRONTEND_URL.rstrip("/")
|
||||
success_url = f"{base_url}/dashboard/{search_space_id}/purchase-success"
|
||||
cancel_url = f"{base_url}/dashboard/{search_space_id}/purchase-cancel"
|
||||
return success_url, cancel_url
|
||||
|
||||
|
||||
def _get_required_stripe_price_id() -> str:
|
||||
if not config.STRIPE_PRICE_ID:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="STRIPE_PRICE_ID is not configured.",
|
||||
)
|
||||
return config.STRIPE_PRICE_ID
|
||||
|
||||
|
||||
def _normalize_optional_string(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return getattr(value, "id", str(value))
|
||||
|
||||
|
||||
def _get_metadata(checkout_session: Any) -> dict[str, str]:
|
||||
metadata = getattr(checkout_session, "metadata", None) or {}
|
||||
if isinstance(metadata, dict):
|
||||
return {str(key): str(value) for key, value in metadata.items()}
|
||||
return dict(metadata)
|
||||
|
||||
|
||||
async def _get_or_create_purchase_from_checkout_session(
|
||||
db_session: AsyncSession,
|
||||
checkout_session: Any,
|
||||
) -> PagePurchase | None:
|
||||
"""Look up a PagePurchase by checkout session ID (with FOR UPDATE lock).
|
||||
|
||||
If the row doesn't exist yet (e.g. the webhook arrived before the API
|
||||
response committed), create one from the Stripe session metadata.
|
||||
"""
|
||||
checkout_session_id = str(checkout_session.id)
|
||||
purchase = (
|
||||
await db_session.execute(
|
||||
select(PagePurchase)
|
||||
.where(PagePurchase.stripe_checkout_session_id == checkout_session_id)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if purchase is not None:
|
||||
return purchase
|
||||
|
||||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
pages_per_unit = int(metadata.get("pages_per_unit", "0"))
|
||||
|
||||
if not user_id or quantity <= 0 or pages_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping Stripe fulfillment for session %s due to incomplete metadata: %s",
|
||||
checkout_session_id,
|
||||
metadata,
|
||||
)
|
||||
return None
|
||||
|
||||
purchase = PagePurchase(
|
||||
user_id=uuid.UUID(user_id),
|
||||
stripe_checkout_session_id=checkout_session_id,
|
||||
stripe_payment_intent_id=_normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
pages_granted=quantity * pages_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PagePurchaseStatus.PENDING,
|
||||
)
|
||||
db_session.add(purchase)
|
||||
await db_session.flush()
|
||||
return purchase
|
||||
|
||||
|
||||
async def _mark_purchase_failed(
|
||||
db_session: AsyncSession, checkout_session_id: str
|
||||
) -> StripeWebhookResponse:
|
||||
purchase = (
|
||||
await db_session.execute(
|
||||
select(PagePurchase)
|
||||
.where(PagePurchase.stripe_checkout_session_id == checkout_session_id)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if purchase is not None and purchase.status == PagePurchaseStatus.PENDING:
|
||||
purchase.status = PagePurchaseStatus.FAILED
|
||||
await db_session.commit()
|
||||
|
||||
return StripeWebhookResponse()
|
||||
|
||||
|
||||
async def _fulfill_completed_purchase(
|
||||
db_session: AsyncSession, checkout_session: Any
|
||||
) -> StripeWebhookResponse:
|
||||
"""Grant pages to the user after a confirmed Stripe payment.
|
||||
|
||||
Uses SELECT ... FOR UPDATE on both the PagePurchase and User rows to
|
||||
prevent double-granting when Stripe retries the webhook concurrently.
|
||||
"""
|
||||
purchase = await _get_or_create_purchase_from_checkout_session(
|
||||
db_session, checkout_session
|
||||
)
|
||||
if purchase is None:
|
||||
return StripeWebhookResponse()
|
||||
|
||||
if purchase.status == PagePurchaseStatus.COMPLETED:
|
||||
return StripeWebhookResponse()
|
||||
|
||||
user = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(User).where(User.id == purchase.user_id).with_for_update(of=User)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if user is None:
|
||||
logger.error(
|
||||
"Skipping Stripe fulfillment for session %s because user %s was not found.",
|
||||
purchase.stripe_checkout_session_id,
|
||||
purchase.user_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
purchase.status = PagePurchaseStatus.COMPLETED
|
||||
purchase.completed_at = datetime.now(UTC)
|
||||
purchase.amount_total = getattr(checkout_session, "amount_total", None)
|
||||
purchase.currency = getattr(checkout_session, "currency", None)
|
||||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
# pages_used can exceed pages_limit when a document's final page count is
|
||||
# determined after processing. Base the new limit on the higher of the two
|
||||
# so the purchased pages are fully usable above the current high-water mark.
|
||||
user.pages_limit = max(user.pages_used, user.pages_limit) + purchase.pages_granted
|
||||
|
||||
await db_session.commit()
|
||||
return StripeWebhookResponse()
|
||||
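The limit update at the end of `_fulfill_completed_purchase` can be checked with a small worked example. `new_pages_limit` is a hypothetical helper that isolates the arithmetic; the numbers are illustrative only.

```python
def new_pages_limit(pages_used: int, pages_limit: int, pages_granted: int) -> int:
    """Grant purchased pages on top of the higher of usage and current limit."""
    return max(pages_used, pages_limit) + pages_granted


# A user who overshot a 100-page limit (120 used) and buys 500 pages gets the
# full 500 pages of headroom above current usage.
overshoot = new_pages_limit(pages_used=120, pages_limit=100, pages_granted=500)

# A user still under the limit simply has the purchase added to the limit.
under_limit = new_pages_limit(pages_used=40, pages_limit=100, pages_granted=500)
```

With a plain `pages_limit + pages_granted`, the first user would end up with a limit of 600 and only 480 usable pages, which is the case the comment in the handler guards against.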
|
||||
|
||||
@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse)
|
||||
async def create_checkout_session(
|
||||
body: CreateCheckoutSessionRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe Checkout Session for buying page packs."""
|
||||
_ensure_page_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_stripe_price_id()
|
||||
success_url, cancel_url = _get_checkout_urls(body.search_space_id)
|
||||
pages_granted = body.quantity * config.STRIPE_PAGES_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
params={
|
||||
"mode": "payment",
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"line_items": [
|
||||
{
|
||||
"price": price_id,
|
||||
"quantity": body.quantity,
|
||||
}
|
||||
],
|
||||
"client_reference_id": str(user.id),
|
||||
"customer_email": user.email,
|
||||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"pages_per_unit": str(config.STRIPE_PAGES_PER_UNIT),
|
||||
"purchase_type": "page_packs",
|
||||
},
|
||||
}
|
||||
)
|
||||
except StripeError as exc:
|
||||
logger.exception(
|
||||
"Failed to create Stripe checkout session for user %s", user.id
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to create Stripe checkout session.",
|
||||
) from exc
|
||||
|
||||
checkout_url = getattr(checkout_session, "url", None)
|
||||
if not checkout_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Stripe checkout session did not return a URL.",
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
PagePurchase(
|
||||
user_id=user.id,
|
||||
stripe_checkout_session_id=str(checkout_session.id),
|
||||
stripe_payment_intent_id=_normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
pages_granted=pages_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PagePurchaseStatus.PENDING,
|
||||
)
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
return CreateCheckoutSessionResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
@router.get("/status", response_model=StripeStatusResponse)
|
||||
async def get_stripe_status() -> StripeStatusResponse:
|
||||
"""Return page-buying availability for frontend feature gating."""
|
||||
return StripeStatusResponse(page_buying_enabled=config.STRIPE_PAGE_BUYING_ENABLED)
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=StripeWebhookResponse)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> StripeWebhookResponse:
|
||||
"""Handle Stripe webhooks and grant purchased pages after payment."""
|
||||
if not config.STRIPE_WEBHOOK_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Stripe webhook handling is not configured.",
|
||||
)
|
||||
|
||||
stripe_client = get_stripe_client()
|
||||
payload = await request.body()
|
||||
signature = request.headers.get("Stripe-Signature")
|
||||
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing Stripe-Signature header.",
|
||||
)
|
||||
|
||||
try:
|
||||
event = stripe_client.construct_event(
|
||||
payload,
|
||||
signature,
|
||||
config.STRIPE_WEBHOOK_SECRET,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe webhook payload.",
|
||||
) from exc
|
||||
except SignatureVerificationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe webhook signature.",
|
||||
) from exc
|
||||
|
||||
if event.type in {
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
}:
|
||||
checkout_session = event.data.object
|
||||
payment_status = getattr(checkout_session, "payment_status", None)
|
||||
|
||||
if event.type == "checkout.session.completed" and payment_status not in {
|
||||
"paid",
|
||||
"no_payment_required",
|
||||
}:
|
||||
logger.info(
|
||||
"Received checkout.session.completed for unpaid session %s; waiting for async success.",
|
||||
checkout_session.id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
return await _fulfill_completed_purchase(db_session, checkout_session)
|
||||
|
||||
if event.type in {
|
||||
"checkout.session.async_payment_failed",
|
||||
"checkout.session.expired",
|
||||
}:
|
||||
checkout_session = event.data.object
|
||||
return await _mark_purchase_failed(db_session, str(checkout_session.id))
|
||||
|
||||
return StripeWebhookResponse()
|
||||
|
||||
|
||||
@router.get("/purchases", response_model=PagePurchaseHistoryResponse)
|
||||
async def get_page_purchases(
|
||||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> PagePurchaseHistoryResponse:
|
||||
"""Return the authenticated user's page-purchase history."""
|
||||
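limit = max(0, min(limit, 100))  # clamp to 0..100; a negative LIMIT is rejected by the database
offset = max(offset, 0)  # a negative OFFSET is rejected as well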
|
||||
purchases = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(PagePurchase)
|
||||
.where(PagePurchase.user_id == user.id)
|
||||
.order_by(PagePurchase.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return PagePurchaseHistoryResponse(purchases=purchases)
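
The webhook handler above routes on the event type and, for `checkout.session.completed`, on the payment status. A minimal sketch of that decision as a pure function, which may be easier to unit-test than the full route. `classify_checkout_event` and the returned action strings are hypothetical names and are not part of this PR; the event names and statuses are the ones checked in `stripe_webhook`.

```python
from __future__ import annotations

from typing import Literal

Action = Literal["fulfill", "wait", "fail", "ignore"]

# Event names and payment statuses copied from the checks in stripe_webhook.
SUCCESS_EVENTS = {
    "checkout.session.completed",
    "checkout.session.async_payment_succeeded",
}
FAILURE_EVENTS = {
    "checkout.session.async_payment_failed",
    "checkout.session.expired",
}
PAID_STATUSES = {"paid", "no_payment_required"}


def classify_checkout_event(event_type: str, payment_status: str | None) -> Action:
    """Hypothetical helper that mirrors the branch order of the handler."""
    if event_type in SUCCESS_EVENTS:
        # A completed session with a delayed payment method is not paid yet,
        # so the handler acknowledges it and waits for the async success event.
        if (
            event_type == "checkout.session.completed"
            and payment_status not in PAID_STATUSES
        ):
            return "wait"
        return "fulfill"
    if event_type in FAILURE_EVENTS:
        return "fail"
    # Any other event type is acknowledged without side effects.
    return "ignore"
```

Keeping the routing separate from the database calls means the branch order can be checked without a Stripe client or a session.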
|
||||
|
|
@ -88,7 +88,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user)
|
|||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
-if not config.TEAMS_CLIENT_ID:
+if not config.MICROSOFT_CLIENT_ID:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Microsoft Teams OAuth not configured."
|
||||
)
|
||||
|
|
@ -106,7 +106,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user)
|
|||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
-"client_id": config.TEAMS_CLIENT_ID,
+"client_id": config.MICROSOFT_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.TEAMS_REDIRECT_URI,
|
||||
"response_mode": "query",
|
||||
|
|
@ -181,8 +181,8 @@ async def teams_callback(
|
|||
|
||||
# Exchange authorization code for access token
|
||||
token_data = {
-"client_id": config.TEAMS_CLIENT_ID,
-"client_secret": config.TEAMS_CLIENT_SECRET,
+"client_id": config.MICROSOFT_CLIENT_ID,
+"client_secret": config.MICROSOFT_CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": config.TEAMS_REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
|
|
@ -403,8 +403,8 @@ async def refresh_teams_token(
|
|||
|
||||
# Microsoft uses oauth2/v2.0/token for token refresh
|
||||
refresh_data = {
-"client_id": config.TEAMS_CLIENT_ID,
-"client_secret": config.TEAMS_CLIENT_SECRET,
+"client_id": config.MICROSOFT_CLIENT_ID,
+"client_secret": config.MICROSOFT_CLIENT_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"scope": " ".join(SCOPES),
|
||||
|
|
|
|||
|
|
@ -110,6 +110,14 @@ from .search_space import (
|
|||
SearchSpaceUpdate,
|
||||
SearchSpaceWithStats,
|
||||
)
|
||||
from .stripe import (
|
||||
CreateCheckoutSessionRequest,
|
||||
CreateCheckoutSessionResponse,
|
||||
PagePurchaseHistoryResponse,
|
||||
PagePurchaseRead,
|
||||
StripeStatusResponse,
|
||||
StripeWebhookResponse,
|
||||
)
|
||||
from .users import UserCreate, UserRead, UserUpdate
|
||||
from .video_presentations import (
|
||||
VideoPresentationBase,
|
||||
|
|
@ -128,6 +136,8 @@ __all__ = [
|
|||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
"CreateCheckoutSessionRequest",
|
||||
"CreateCheckoutSessionResponse",
|
||||
"DefaultSystemInstructionsResponse",
|
||||
# Document schemas
|
||||
"DocumentBase",
|
||||
|
|
@ -207,6 +217,8 @@ __all__ = [
|
|||
"NewLLMConfigPublic",
|
||||
"NewLLMConfigRead",
|
||||
"NewLLMConfigUpdate",
|
||||
"PagePurchaseHistoryResponse",
|
||||
"PagePurchaseRead",
|
||||
"PaginatedResponse",
|
||||
"PermissionInfo",
|
||||
"PermissionsListResponse",
|
||||
|
|
@ -236,6 +248,8 @@ __all__ = [
|
|||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceWithStats",
|
||||
"StripeStatusResponse",
|
||||
"StripeWebhookResponse",
|
||||
"ThreadHistoryLoadResponse",
|
||||
"ThreadListItem",
|
||||
"ThreadListResponse",
|
||||
|
|
|
|||
71
surfsense_backend/app/schemas/onedrive_auth_credentials.py
Normal file
|
|
@ -0,0 +1,71 @@
|
"""Microsoft OneDrive OAuth credentials schema."""

from datetime import UTC, datetime

from pydantic import BaseModel, field_validator


class OneDriveAuthCredentialsBase(BaseModel):
    """Microsoft OneDrive OAuth credentials."""

    access_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_in: int | None = None
    expires_at: datetime | None = None
    scope: str | None = None
    user_email: str | None = None
    user_name: str | None = None
    tenant_id: str | None = None

    @property
    def is_expired(self) -> bool:
        if self.expires_at is None:
            return False
        return self.expires_at <= datetime.now(UTC)

    @property
    def is_refreshable(self) -> bool:
        return self.refresh_token is not None

    def to_dict(self) -> dict:
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "token_type": self.token_type,
            "expires_in": self.expires_in,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "scope": self.scope,
            "user_email": self.user_email,
            "user_name": self.user_name,
            "tenant_id": self.tenant_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "OneDriveAuthCredentialsBase":
        expires_at = None
        if data.get("expires_at"):
            expires_at = datetime.fromisoformat(data["expires_at"])
        return cls(
            access_token=data.get("access_token", ""),
            refresh_token=data.get("refresh_token"),
            token_type=data.get("token_type", "Bearer"),
            expires_in=data.get("expires_in"),
            expires_at=expires_at,
            scope=data.get("scope"),
            user_email=data.get("user_email"),
            user_name=data.get("user_name"),
            tenant_id=data.get("tenant_id"),
        )

    @field_validator("expires_at", mode="before")
    @classmethod
    def ensure_aware_utc(cls, v):
        if isinstance(v, str):
            if v.endswith("Z"):
                return datetime.fromisoformat(v.replace("Z", "+00:00"))
            dt = datetime.fromisoformat(v)
            return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
        if isinstance(v, datetime):
            return v if v.tzinfo else v.replace(tzinfo=UTC)
        return v
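
The `ensure_aware_utc` validator exists so that `is_expired` never compares a naive timestamp with an aware one, which would raise a `TypeError`. The idea can be sketched with the standard library alone, assuming naive values are already in UTC as the schema does. `to_aware_utc` and the free function `is_expired` are hypothetical stand-ins, not code from this PR.

```python
from datetime import datetime, timedelta, timezone


def to_aware_utc(value):
    """Hypothetical stand-alone version of the ensure_aware_utc validator."""
    if isinstance(value, str):
        # Rewrite a trailing "Z" so fromisoformat accepts it on older Pythons.
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        value = datetime.fromisoformat(value)
    if isinstance(value, datetime) and value.tzinfo is None:
        # Assumption carried over from the schema: naive means UTC.
        value = value.replace(tzinfo=timezone.utc)
    return value


def is_expired(expires_at, now=None):
    """Credentials without an expiry are treated as not expired."""
    if expires_at is None:
        return False
    now = now or datetime.now(timezone.utc)
    return to_aware_utc(expires_at) <= now
```

Passing `now` explicitly keeps the check deterministic in tests.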
|
||||
36
surfsense_backend/app/schemas/prompts.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from datetime import datetime

from pydantic import BaseModel, Field


class PromptCreate(BaseModel):
    name: str = Field(..., min_length=1, max_length=200)
    prompt: str = Field(..., min_length=1)
    mode: str = Field(..., pattern="^(transform|explore)$")
    search_space_id: int | None = None
    is_public: bool = False


class PromptUpdate(BaseModel):
    name: str | None = Field(None, min_length=1, max_length=200)
    prompt: str | None = Field(None, min_length=1)
    mode: str | None = Field(None, pattern="^(transform|explore)$")
    is_public: bool | None = None


class PromptRead(BaseModel):
    id: int
    name: str
    prompt: str
    mode: str
    search_space_id: int | None
    is_public: bool
    version: int
    created_at: datetime

    class Config:
        from_attributes = True


class PublicPromptRead(PromptRead):
    author_name: str | None = None
|
||||
56
surfsense_backend/app/schemas/stripe.py
Normal file
|
|
@ -0,0 +1,56 @@
|
"""Schemas for Stripe-backed page purchases."""

import uuid
from datetime import datetime

from pydantic import BaseModel, ConfigDict, Field

from app.db import PagePurchaseStatus


class CreateCheckoutSessionRequest(BaseModel):
    """Request body for creating a page-purchase checkout session."""

    quantity: int = Field(ge=1, le=100)
    search_space_id: int = Field(ge=1)


class CreateCheckoutSessionResponse(BaseModel):
    """Response containing the Stripe-hosted checkout URL."""

    checkout_url: str


class StripeStatusResponse(BaseModel):
    """Response describing Stripe page-buying availability."""

    page_buying_enabled: bool


class PagePurchaseRead(BaseModel):
    """Serialized page-purchase record for purchase history."""

    id: uuid.UUID
    stripe_checkout_session_id: str
    stripe_payment_intent_id: str | None = None
    quantity: int
    pages_granted: int
    amount_total: int | None = None
    currency: str | None = None
    status: PagePurchaseStatus
    completed_at: datetime | None = None
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)


class PagePurchaseHistoryResponse(BaseModel):
    """Response containing the authenticated user's page purchases."""

    purchases: list[PagePurchaseRead]


class StripeWebhookResponse(BaseModel):
    """Generic acknowledgement for Stripe webhook delivery."""

    received: bool = True
|
||||
5
surfsense_backend/app/services/dropbox/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from app.services.dropbox.kb_sync_service import DropboxKBSyncService

__all__ = [
    "DropboxKBSyncService",
]
|
||||
159
surfsense_backend/app/services/dropbox/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DropboxKBSyncService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
file_id: str,
|
||||
file_name: str,
|
||||
file_path: str,
|
||||
web_url: str | None,
|
||||
content: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = compute_identifier_hash(
|
||||
DocumentType.DROPBOX_FILE.value, file_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for Dropbox file %s already exists (doc_id=%s), skipping",
|
||||
file_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (content or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"Dropbox file: {file_name}"
|
||||
|
||||
content_hash = generate_content_hash(indexable_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for Dropbox file %s — identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
file_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"file_name": file_name,
|
||||
"document_type": "Dropbox File",
|
||||
"connector_type": "Dropbox",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured — using fallback summary")
|
||||
summary_content = f"Dropbox File: {file_name}\n\n{indexable_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=file_name,
|
||||
document_type=DocumentType.DROPBOX_FILE,
|
||||
document_metadata={
|
||||
"dropbox_file_id": file_id,
|
||||
"dropbox_file_name": file_name,
|
||||
"dropbox_path": file_path,
|
||||
"web_url": web_url,
|
||||
"source_connector": "dropbox",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
source_markdown=content,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, file=%s, chunks=%d",
|
||||
document.id,
|
||||
file_name,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for file %s. "
|
||||
"Rolling back — periodic indexer will handle it. Error: %s",
|
||||
file_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for file %s: %s",
|
||||
file_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
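
`sync_after_create` stores two hashes per document: an identifier hash that is unique per file and search space, and a content hash that falls back to the identifier hash when identical content is already indexed. A minimal sketch of that fallback, under loud assumptions: the diff does not show how `compute_identifier_hash` or `generate_content_hash` are implemented, so the SHA-256 helpers below are hypothetical stand-ins, and `existing_hashes` replaces the database lookup.

```python
import hashlib


def identifier_hash(doc_type: str, file_id: str, search_space_id: int) -> str:
    """Hypothetical stand-in for compute_identifier_hash."""
    raw = f"{doc_type}:{file_id}:{search_space_id}"
    return hashlib.sha256(raw.encode()).hexdigest()


def content_hash(content: str, search_space_id: int) -> str:
    """Hypothetical stand-in for generate_content_hash."""
    raw = f"{search_space_id}:{content}"
    return hashlib.sha256(raw.encode()).hexdigest()


def pick_content_hash(content, file_id, search_space_id, existing_hashes):
    """Return the hash to store, falling back to the identifier hash on a collision."""
    candidate = content_hash(content, search_space_id)
    if candidate in existing_hashes:
        # Identical content is already indexed under another document, so the
        # per-file identifier hash is reused to keep the stored value unique.
        return identifier_hash("DROPBOX_FILE", file_id, search_space_id)
    return candidate
```

The fallback trades content-level deduplication for a guaranteed insert: two files with the same text both get indexed, each under its own identifier.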
|
||||
5
surfsense_backend/app/services/onedrive/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from app.services.onedrive.kb_sync_service import OneDriveKBSyncService

__all__ = [
    "OneDriveKBSyncService",
]
|
||||
160
surfsense_backend/app/services/onedrive/kb_sync_service.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OneDriveKBSyncService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def sync_after_create(
|
||||
self,
|
||||
file_id: str,
|
||||
file_name: str,
|
||||
mime_type: str,
|
||||
web_url: str | None,
|
||||
content: str | None,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict:
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document_by_hash,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
try:
|
||||
unique_hash = compute_identifier_hash(
|
||||
DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id
|
||||
)
|
||||
|
||||
existing = await check_document_by_unique_identifier(
|
||||
self.db_session, unique_hash
|
||||
)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Document for OneDrive file %s already exists (doc_id=%s), skipping",
|
||||
file_id,
|
||||
existing.id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
indexable_content = (content or "").strip()
|
||||
if not indexable_content:
|
||||
indexable_content = f"OneDrive file: {file_name} (type: {mime_type})"
|
||||
|
||||
content_hash = generate_content_hash(indexable_content, search_space_id)
|
||||
|
||||
with self.db_session.no_autoflush:
|
||||
dup = await check_duplicate_document_by_hash(
|
||||
self.db_session, content_hash
|
||||
)
|
||||
if dup:
|
||||
logger.info(
|
||||
"Content-hash collision for OneDrive file %s — identical content "
|
||||
"exists in doc %s. Using unique_identifier_hash as content_hash.",
|
||||
file_id,
|
||||
dup.id,
|
||||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session,
|
||||
user_id,
|
||||
search_space_id,
|
||||
disable_streaming=True,
|
||||
)
|
||||
|
||||
doc_metadata_for_summary = {
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"document_type": "OneDrive File",
|
||||
"connector_type": "OneDrive",
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
indexable_content, user_llm, doc_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
logger.warning("No LLM configured — using fallback summary")
|
||||
summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(indexable_content)
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
document = Document(
|
||||
title=file_name,
|
||||
document_type=DocumentType.ONEDRIVE_FILE,
|
||||
document_metadata={
|
||||
"onedrive_file_id": file_id,
|
||||
"onedrive_file_name": file_name,
|
||||
"onedrive_mime_type": mime_type,
|
||||
"web_url": web_url,
|
||||
"source_connector": "onedrive",
|
||||
"indexed_at": now_str,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content=summary_content,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_hash,
|
||||
embedding=summary_embedding,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
source_markdown=content,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
self.db_session.add(document)
|
||||
await self.db_session.flush()
|
||||
await safe_set_chunks(self.db_session, document, chunks)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"KB sync after create succeeded: doc_id=%s, file=%s, chunks=%d",
|
||||
document.id,
|
||||
file_name,
|
||||
len(chunks),
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in error_str
|
||||
or "uniqueviolationerror" in error_str
|
||||
):
|
||||
logger.warning(
|
||||
"Duplicate constraint hit during KB sync for file %s. "
|
||||
"Rolling back — periodic indexer will handle it. Error: %s",
|
||||
file_id,
|
||||
e,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": "Duplicate document detected"}
|
||||
|
||||
logger.error(
|
||||
"KB sync after create failed for file %s: %s",
|
||||
file_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
|
@ -526,6 +526,102 @@ async def _index_google_drive_files(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_onedrive_files", bind=True)
|
||||
def index_onedrive_files_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index OneDrive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_onedrive_files(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Index OneDrive folders and files with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_onedrive_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_onedrive_indexing(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_dropbox_files", bind=True)
|
||||
def index_dropbox_files_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index Dropbox folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_dropbox_files(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
):
|
||||
"""Index Dropbox folders and files with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_dropbox_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_dropbox_indexing(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_discord_messages", bind=True)
|
||||
def index_discord_messages_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
"""Reconcile pending Stripe page purchases that might miss webhook fulfillment."""

from __future__ import annotations

import asyncio
import logging
from datetime import UTC, datetime, timedelta

from sqlalchemy import select
from stripe import StripeClient, StripeError

from app.celery_app import celery_app
from app.config import config
from app.db import PagePurchase, PagePurchaseStatus
from app.routes import stripe_routes
from app.tasks.celery_tasks import get_celery_session_maker

logger = logging.getLogger(__name__)


def get_stripe_client() -> StripeClient | None:
    """Return a Stripe client for reconciliation, or None when disabled."""
    if not config.STRIPE_SECRET_KEY:
        logger.warning(
            "Stripe reconciliation skipped because STRIPE_SECRET_KEY is not configured."
        )
        return None
    return StripeClient(config.STRIPE_SECRET_KEY)


@celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task():
    """Recover paid purchases that were left pending due to missed webhook handling."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        loop.run_until_complete(_reconcile_pending_page_purchases())
    finally:
        loop.close()


async def _reconcile_pending_page_purchases() -> None:
    """Reconcile stale pending page purchases against Stripe source of truth.

    Stripe retries webhook delivery automatically, but best practice is to add an
    application-level reconciliation path in case all retries fail or the endpoint
    is unavailable for an extended window.
    """
    stripe_client = get_stripe_client()
    if stripe_client is None:
        return

    lookback_minutes = max(config.STRIPE_RECONCILIATION_LOOKBACK_MINUTES, 0)
    batch_size = max(config.STRIPE_RECONCILIATION_BATCH_SIZE, 1)
    cutoff = datetime.now(UTC) - timedelta(minutes=lookback_minutes)

    async with get_celery_session_maker()() as db_session:
        pending_purchases = (
            (
                await db_session.execute(
                    select(PagePurchase)
                    .where(
                        PagePurchase.status == PagePurchaseStatus.PENDING,
                        PagePurchase.created_at <= cutoff,
                    )
                    .order_by(PagePurchase.created_at.asc())
                    .limit(batch_size)
                )
            )
            .scalars()
            .all()
        )

        if not pending_purchases:
            logger.debug(
                "Stripe reconciliation found no pending purchases older than %s minutes.",
                lookback_minutes,
            )
            return

        # The second value logged is the lookback window in minutes, not the cutoff timestamp.
        logger.info(
            "Stripe reconciliation checking %s pending purchases (lookback_minutes=%s, batch=%s).",
            len(pending_purchases),
            lookback_minutes,
            batch_size,
        )

        fulfilled_count = 0
        failed_count = 0

        for purchase in pending_purchases:
            checkout_session_id = purchase.stripe_checkout_session_id

            try:
                checkout_session = stripe_client.v1.checkout.sessions.retrieve(
                    checkout_session_id
                )
            except StripeError:
                logger.exception(
                    "Stripe reconciliation failed to retrieve checkout session %s",
                    checkout_session_id,
                )
                await db_session.rollback()
                continue

            payment_status = getattr(checkout_session, "payment_status", None)
            session_status = getattr(checkout_session, "status", None)

            try:
                if payment_status in {"paid", "no_payment_required"}:
                    await stripe_routes._fulfill_completed_purchase(
                        db_session, checkout_session
                    )
                    fulfilled_count += 1
                elif session_status == "expired":
                    await stripe_routes._mark_purchase_failed(
                        db_session, str(checkout_session.id)
                    )
                    failed_count += 1
            except Exception:
                logger.exception(
                    "Stripe reconciliation failed while processing checkout session %s",
                    checkout_session_id,
                )
                await db_session.rollback()

        logger.info(
            "Stripe reconciliation completed. fulfilled=%s failed=%s checked=%s",
            fulfilled_count,
            failed_count,
            len(pending_purchases),
        )
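
The reconciliation task has two separable pieces: selecting stale pending purchases (oldest first, capped by a batch size) and deciding what to do with each retrieved checkout session. A sketch of both as pure functions, assuming purchases are plain dicts with hypothetical `status` and `created_at` keys and a hypothetical `"pending"` status string. The real task queries `PagePurchase` rows and compares against `PagePurchaseStatus.PENDING`.

```python
from datetime import datetime, timedelta, timezone


def select_stale_pending(purchases, now, lookback_minutes, batch_size):
    """Pick the oldest pending purchases created at or before the cutoff."""
    # Same guards as the task: no negative lookback, at least one row per batch.
    cutoff = now - timedelta(minutes=max(lookback_minutes, 0))
    stale = [
        p
        for p in purchases
        if p["status"] == "pending" and p["created_at"] <= cutoff
    ]
    stale.sort(key=lambda p: p["created_at"])
    return stale[: max(batch_size, 1)]


def reconcile_action(payment_status, session_status):
    """Mirror the task's branches: fulfill, mark failed, or leave pending."""
    if payment_status in {"paid", "no_payment_required"}:
        return "fulfill"
    if session_status == "expired":
        return "fail"
    # Open, unpaid sessions are left alone and rechecked on the next run.
    return "leave_pending"
```

The lookback window gives the webhook a chance to arrive first, so the task only touches purchases that have been pending for longer than expected.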
|
||||
|
|
@ -1023,6 +1023,10 @@ async def _stream_agent_events(
|
|||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
"create_onedrive_file",
|
||||
"delete_onedrive_file",
|
||||
"create_dropbox_file",
|
||||
"delete_dropbox_file",
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
|
|
@ -1073,6 +1077,37 @@ async def _stream_agent_events(
|
|||
"thread_id": thread_id_str,
|
||||
},
|
||||
)
|
||||
elif tool_name == "web_search":
|
||||
xml = (
|
||||
tool_output.get("result", str(tool_output))
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
citations: dict[str, dict[str, str]] = {}
|
||||
for m in re.finditer(
|
||||
r"<title><!\[CDATA\[(.*?)\]\]></title>\s*<url><!\[CDATA\[(.*?)\]\]></url>",
|
||||
xml,
|
||||
):
|
||||
title, url = m.group(1).strip(), m.group(2).strip()
|
||||
if url.startswith("http") and url not in citations:
|
||||
citations[url] = {"title": title}
|
||||
for m in re.finditer(
|
||||
r"<chunk\s+id='([^']*)'><!\[CDATA\[([\s\S]*?)\]\]></chunk>",
|
||||
xml,
|
||||
):
|
||||
chunk_url, content = m.group(1).strip(), m.group(2).strip()
|
||||
if (
|
||||
chunk_url.startswith("http")
|
||||
and chunk_url in citations
|
||||
and content
|
||||
):
|
||||
citations[chunk_url]["snippet"] = (
|
||||
content[:200] + "…" if len(content) > 200 else content
|
||||
)
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{"status": "completed", "citations": citations},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
|
|||
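
The `web_search` branch added in the hunk above builds a citation map from the tool's XML output in two regex passes: one for title and URL pairs, one for chunk snippets. A stand-alone sketch of the same extraction. The two regular expressions are copied from the diff; `extract_citations` and its `max_snippet` parameter are hypothetical, and the exact shape of the XML beyond what the regexes match is an assumption.

```python
import re

# Patterns copied from the web_search branch of _stream_agent_events.
TITLE_URL = re.compile(
    r"<title><!\[CDATA\[(.*?)\]\]></title>\s*<url><!\[CDATA\[(.*?)\]\]></url>"
)
CHUNK = re.compile(r"<chunk\s+id='([^']*)'><!\[CDATA\[([\s\S]*?)\]\]></chunk>")


def extract_citations(xml: str, max_snippet: int = 200) -> dict:
    """Map each http(s) URL to its title and, when available, a short snippet."""
    citations: dict[str, dict[str, str]] = {}
    # Pass 1: the first title seen for each URL wins.
    for m in TITLE_URL.finditer(xml):
        title, url = m.group(1).strip(), m.group(2).strip()
        if url.startswith("http") and url not in citations:
            citations[url] = {"title": title}
    # Pass 2: attach snippets only to URLs that already have a title entry.
    for m in CHUNK.finditer(xml):
        chunk_url, content = m.group(1).strip(), m.group(2).strip()
        if chunk_url in citations and content:
            citations[chunk_url]["snippet"] = (
                content[:max_snippet] + "…"
                if len(content) > max_snippet
                else content
            )
    return citations
```

Chunks whose `id` is not a previously seen URL are skipped, so the result never contains a snippet without a title.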
|
|
@ -0,0 +1,527 @@
|
|||
"""Dropbox indexer using the shared IndexingPipelineService.
|
||||
|
||||
File-level pre-filter (_should_skip_file) handles content_hash and
|
||||
server_modified checks. download_and_extract_content() returns
|
||||
markdown which is fed into ConnectorDocument -> pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from sqlalchemy import String, cast, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.dropbox import (
|
||||
DropboxClient,
|
||||
download_and_extract_content,
|
||||
get_file_by_path,
|
||||
get_files_in_folder,
|
||||
)
|
||||
from app.connectors.dropbox.file_types import should_skip_file as skip_item
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
get_connector_by_id,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _should_skip_file(
    session: AsyncSession,
    file: dict,
    search_space_id: int,
) -> tuple[bool, str | None]:
    """Pre-filter: detect unchanged / rename-only files."""
    file_id = file.get("id", "")
    file_name = file.get("name", "Unknown")

    if skip_item(file):
        return True, "folder/non-downloadable"
    if not file_id:
        return True, "missing file_id"

    primary_hash = compute_identifier_hash(
        DocumentType.DROPBOX_FILE.value, file_id, search_space_id
    )
    existing = await check_document_by_unique_identifier(session, primary_hash)

    if not existing:
        # Fallback: match on the Dropbox file ID stored in document_metadata
        # and, if a document is found, assign the primary hash to it.
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.document_type == DocumentType.DROPBOX_FILE,
                cast(Document.document_metadata["dropbox_file_id"], String) == file_id,
            )
        )
        existing = result.scalar_one_or_none()
        if existing:
            existing.unique_identifier_hash = primary_hash
            logger.debug(f"Found Dropbox doc by metadata for file_id: {file_id}")

    if not existing:
        # No stored document for this file: it must be indexed.
        return False, None

    incoming_content_hash = file.get("content_hash")
    meta = existing.document_metadata or {}
    stored_content_hash = meta.get("content_hash")

    incoming_mtime = file.get("server_modified")
    stored_mtime = meta.get("modified_time")

    # Compare content hashes when both exist; compare modification times only
    # when the incoming file has no hash. Any other combination re-indexes.
    content_unchanged = False
    if incoming_content_hash and stored_content_hash:
        content_unchanged = incoming_content_hash == stored_content_hash
    elif incoming_content_hash and not stored_content_hash:
        return False, None
    elif not incoming_content_hash and incoming_mtime and stored_mtime:
        content_unchanged = incoming_mtime == stored_mtime
    elif not incoming_content_hash:
        return False, None

    if not content_unchanged:
        return False, None

    # Content is unchanged but the name differs: update the title and
    # metadata in place instead of downloading the file again.
    old_name = meta.get("dropbox_file_name")
    if old_name and old_name != file_name:
        existing.title = file_name
        if not existing.document_metadata:
            existing.document_metadata = {}
        existing.document_metadata["dropbox_file_name"] = file_name
        if incoming_mtime:
            existing.document_metadata["modified_time"] = incoming_mtime
        flag_modified(existing, "document_metadata")
        await session.commit()
        logger.info(f"Rename-only update: '{old_name}' -> '{file_name}'")
        return True, f"File renamed: '{old_name}' -> '{file_name}'"

    # Unchanged content is skipped even when the document is not in READY state.
    if not DocumentStatus.is_state(existing.status, DocumentStatus.READY):
        return True, "skipped (previously failed)"
    return True, "unchanged"
|
||||
|
||||
|
||||
def _build_connector_doc(
|
||||
file: dict,
|
||||
markdown: str,
|
||||
dropbox_metadata: dict,
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
) -> ConnectorDocument:
|
||||
file_id = file.get("id", "")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
metadata = {
|
||||
**dropbox_metadata,
|
||||
"connector_id": connector_id,
|
||||
"document_type": "Dropbox File",
|
||||
"connector_type": "Dropbox",
|
||||
}
|
||||
|
||||
fallback_summary = f"File: {file_name}\n\n{markdown[:4000]}"
|
||||
|
||||
return ConnectorDocument(
|
||||
title=file_name,
|
||||
source_markdown=markdown,
|
||||
unique_id=file_id,
|
||||
document_type=DocumentType.DROPBOX_FILE,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
created_by_id=user_id,
|
||||
should_summarize=enable_summary,
|
||||
fallback_summary=fallback_summary,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def _download_files_parallel(
|
||||
dropbox_client: DropboxClient,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
max_concurrency: int = 3,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[list[ConnectorDocument], int]:
|
||||
"""Download and ETL files in parallel. Returns (docs, failed_count)."""
|
||||
results: list[ConnectorDocument] = []
|
||||
sem = asyncio.Semaphore(max_concurrency)
|
||||
last_heartbeat = time.time()
|
||||
completed_count = 0
|
||||
hb_lock = asyncio.Lock()
|
||||
|
||||
async def _download_one(file: dict) -> ConnectorDocument | None:
|
||||
nonlocal last_heartbeat, completed_count
|
||||
async with sem:
|
||||
markdown, db_metadata, error = await download_and_extract_content(
|
||||
dropbox_client, file
|
||||
)
|
||||
if error or not markdown:
|
||||
file_name = file.get("name", "Unknown")
|
||||
reason = error or "empty content"
|
||||
logger.warning(f"Download/ETL failed for {file_name}: {reason}")
|
||||
return None
|
||||
doc = _build_connector_doc(
|
||||
file,
|
||||
markdown,
|
||||
db_metadata,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
async with hb_lock:
|
||||
completed_count += 1
|
||||
if on_heartbeat:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat(completed_count)
|
||||
last_heartbeat = now
|
||||
return doc
|
||||
|
||||
tasks = [_download_one(f) for f in files]
|
||||
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
failed = 0
|
||||
for outcome in outcomes:
|
||||
if isinstance(outcome, Exception) or outcome is None:
|
||||
failed += 1
|
||||
else:
|
||||
results.append(outcome)
|
||||
|
||||
return results, failed
|
||||
|
||||
|
||||
async def _download_and_index(
|
||||
dropbox_client: DropboxClient,
|
||||
session: AsyncSession,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Parallel download then parallel indexing. Returns (batch_indexed, total_failed)."""
|
||||
connector_docs, download_failed = await _download_files_parallel(
|
||||
dropbox_client,
|
||||
files,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
batch_indexed = 0
|
||||
batch_failed = 0
|
||||
if connector_docs:
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
async def _index_full_scan(
    dropbox_client: DropboxClient,
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    folder_path: str,
    folder_name: str,
    task_logger: TaskLoggingService,
    log_entry: object,
    max_files: int,
    include_subfolders: bool = True,
    incremental_sync: bool = True,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
    enable_summary: bool = True,
) -> tuple[int, int]:
    """Full scan indexing of a folder."""
    await task_logger.log_task_progress(
        log_entry,
        f"Starting full scan of folder: {folder_name}",
        {
            "stage": "full_scan",
            "folder_path": folder_path,
            "include_subfolders": include_subfolders,
            "incremental_sync": incremental_sync,
        },
    )

    renamed_count = 0
    skipped = 0
    files_to_download: list[dict] = []

    all_files, error = await get_files_in_folder(
        dropbox_client,
        folder_path,
        include_subfolders=include_subfolders,
    )
    if error:
        err_lower = error.lower()
        if "401" in error or "authentication expired" in err_lower:
            raise Exception(
                f"Dropbox authentication failed. Please re-authenticate. (Error: {error})"
            )
        raise Exception(f"Failed to list Dropbox files: {error}")

    for file in all_files[:max_files]:
        if incremental_sync:
            skip, msg = await _should_skip_file(session, file, search_space_id)
            if skip:
                if msg and "renamed" in msg.lower():
                    renamed_count += 1
                else:
                    skipped += 1
                continue
        elif skip_item(file):
            skipped += 1
            continue
        files_to_download.append(file)

    batch_indexed, failed = await _download_and_index(
        dropbox_client,
        session,
        files_to_download,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(
        f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
    )
    return indexed, skipped
|
||||
|
||||
|
||||
async def _index_selected_files(
|
||||
dropbox_client: DropboxClient,
|
||||
session: AsyncSession,
|
||||
file_paths: list[tuple[str, str | None]],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
incremental_sync: bool = True,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
|
||||
for file_path, file_name in file_paths:
|
||||
file, error = await get_file_by_path(dropbox_client, file_path)
|
||||
if error or not file:
|
||||
display = file_name or file_path
|
||||
errors.append(f"File '{display}': {error or 'File not found'}")
|
||||
continue
|
||||
|
||||
if incremental_sync:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
elif skip_item(file):
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
dropbox_client,
|
||||
session,
|
||||
files_to_download,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
async def index_dropbox_files(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index Dropbox files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
{
|
||||
"folders": [{"path": "...", "name": "..."}, ...],
|
||||
"files": [{"path": "...", "name": "..."}, ...],
|
||||
"indexing_options": {
|
||||
"max_files": 500,
|
||||
"incremental_sync": true,
|
||||
"include_subfolders": true,
|
||||
}
|
||||
}
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="dropbox_files_indexing",
|
||||
source="connector_indexing_task",
|
||||
message=f"Starting Dropbox indexing for connector {connector_id}",
|
||||
metadata={"connector_id": connector_id, "user_id": str(user_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.DROPBOX_CONNECTOR
|
||||
)
|
||||
if not connector:
|
||||
error_msg = f"Dropbox connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
error_msg = "SECRET_KEY not configured but credentials are encrypted"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
error_msg,
|
||||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
dropbox_client = DropboxClient(session, connector_id)
|
||||
|
||||
indexing_options = items_dict.get("indexing_options", {})
|
||||
max_files = indexing_options.get("max_files", 500)
|
||||
incremental_sync = indexing_options.get("incremental_sync", True)
|
||||
include_subfolders = indexing_options.get("include_subfolders", True)
|
||||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
file_tuples = [
|
||||
(f.get("path", f.get("path_lower", f.get("id", ""))), f.get("name"))
|
||||
for f in selected_files
|
||||
]
|
||||
indexed, skipped, file_errors = await _index_selected_files(
|
||||
dropbox_client,
|
||||
session,
|
||||
file_tuples,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector_enable_summary,
|
||||
incremental_sync=incremental_sync,
|
||||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
if file_errors:
|
||||
logger.warning(
|
||||
f"File indexing errors for connector {connector_id}: {file_errors}"
|
||||
)
|
||||
|
||||
folders = items_dict.get("folders", [])
|
||||
for folder in folders:
|
||||
folder_path = folder.get(
|
||||
"path", folder.get("path_lower", folder.get("id", ""))
|
||||
)
|
||||
folder_name = folder.get("name", "Root")
|
||||
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_path,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
incremental_sync=incremental_sync,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
if total_indexed > 0 or folders:
|
||||
await update_connector_last_indexed(session, connector, True)
|
||||
|
||||
await session.commit()
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Dropbox indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
)
|
||||
logger.info(
|
||||
f"Dropbox indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Database error during Dropbox indexing for connector {connector_id}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to index Dropbox files for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Dropbox files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index Dropbox files: {e!s}"
|
||||
|
|
@ -0,0 +1,686 @@
|
|||
"""OneDrive indexer using the shared IndexingPipelineService.
|
||||
|
||||
File-level pre-filter (_should_skip_file) handles hash/lastModifiedDateTime
checks and rename-only detection. download_and_extract_content()
returns markdown, which is fed into ConnectorDocument -> pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from sqlalchemy import String, cast, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.onedrive import (
|
||||
OneDriveClient,
|
||||
download_and_extract_content,
|
||||
get_file_by_id,
|
||||
get_files_in_folder,
|
||||
)
|
||||
from app.connectors.onedrive.file_types import should_skip_file as skip_item
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
get_connector_by_id,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _should_skip_file(
|
||||
session: AsyncSession,
|
||||
file: dict,
|
||||
search_space_id: int,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Pre-filter: detect unchanged / rename-only files."""
|
||||
file_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if skip_item(file):
|
||||
return True, "folder/onenote/remote"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
||||
primary_hash = compute_identifier_hash(
|
||||
DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id
|
||||
)
|
||||
existing = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
if not existing:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.ONEDRIVE_FILE,
|
||||
cast(Document.document_metadata["onedrive_file_id"], String) == file_id,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
existing.unique_identifier_hash = primary_hash
|
||||
logger.debug(f"Found OneDrive doc by metadata for file_id: {file_id}")
|
||||
|
||||
if not existing:
|
||||
return False, None
|
||||
|
||||
incoming_mtime = file.get("lastModifiedDateTime")
|
||||
meta = existing.document_metadata or {}
|
||||
stored_mtime = meta.get("modified_time")
|
||||
|
||||
file_info = file.get("file", {})
|
||||
file_hashes = file_info.get("hashes", {})
|
||||
incoming_hash = file_hashes.get("sha256Hash") or file_hashes.get("quickXorHash")
|
||||
stored_hash = meta.get("sha256_hash") or meta.get("quick_xor_hash")
|
||||
|
||||
content_unchanged = False
|
||||
if incoming_hash and stored_hash:
|
||||
content_unchanged = incoming_hash == stored_hash
|
||||
elif incoming_hash and not stored_hash:
|
||||
return False, None
|
||||
elif not incoming_hash and incoming_mtime and stored_mtime:
|
||||
content_unchanged = incoming_mtime == stored_mtime
|
||||
elif not incoming_hash:
|
||||
return False, None
|
||||
|
||||
if not content_unchanged:
|
||||
return False, None
|
||||
|
||||
old_name = meta.get("onedrive_file_name")
|
||||
if old_name and old_name != file_name:
|
||||
existing.title = file_name
|
||||
if not existing.document_metadata:
|
||||
existing.document_metadata = {}
|
||||
existing.document_metadata["onedrive_file_name"] = file_name
|
||||
if incoming_mtime:
|
||||
existing.document_metadata["modified_time"] = incoming_mtime
|
||||
flag_modified(existing, "document_metadata")
|
||||
await session.commit()
|
||||
logger.info(f"Rename-only update: '{old_name}' -> '{file_name}'")
|
||||
return True, f"File renamed: '{old_name}' -> '{file_name}'"
|
||||
|
||||
if not DocumentStatus.is_state(existing.status, DocumentStatus.READY):
|
||||
return True, "skipped (previously failed)"
|
||||
return True, "unchanged"
|
||||
|
||||
|
||||
def _build_connector_doc(
|
||||
file: dict,
|
||||
markdown: str,
|
||||
onedrive_metadata: dict,
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
) -> ConnectorDocument:
|
||||
file_id = file.get("id", "")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
metadata = {
|
||||
**onedrive_metadata,
|
||||
"connector_id": connector_id,
|
||||
"document_type": "OneDrive File",
|
||||
"connector_type": "OneDrive",
|
||||
}
|
||||
|
||||
fallback_summary = f"File: {file_name}\n\n{markdown[:4000]}"
|
||||
|
||||
return ConnectorDocument(
|
||||
title=file_name,
|
||||
source_markdown=markdown,
|
||||
unique_id=file_id,
|
||||
document_type=DocumentType.ONEDRIVE_FILE,
|
||||
search_space_id=search_space_id,
|
||||
connector_id=connector_id,
|
||||
created_by_id=user_id,
|
||||
should_summarize=enable_summary,
|
||||
fallback_summary=fallback_summary,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def _download_files_parallel(
|
||||
onedrive_client: OneDriveClient,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
max_concurrency: int = 3,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[list[ConnectorDocument], int]:
|
||||
"""Download and ETL files in parallel. Returns (docs, failed_count)."""
|
||||
results: list[ConnectorDocument] = []
|
||||
sem = asyncio.Semaphore(max_concurrency)
|
||||
last_heartbeat = time.time()
|
||||
completed_count = 0
|
||||
hb_lock = asyncio.Lock()
|
||||
|
||||
async def _download_one(file: dict) -> ConnectorDocument | None:
|
||||
nonlocal last_heartbeat, completed_count
|
||||
async with sem:
|
||||
markdown, od_metadata, error = await download_and_extract_content(
|
||||
onedrive_client, file
|
||||
)
|
||||
if error or not markdown:
|
||||
file_name = file.get("name", "Unknown")
|
||||
reason = error or "empty content"
|
||||
logger.warning(f"Download/ETL failed for {file_name}: {reason}")
|
||||
return None
|
||||
doc = _build_connector_doc(
|
||||
file,
|
||||
markdown,
|
||||
od_metadata,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
async with hb_lock:
|
||||
completed_count += 1
|
||||
if on_heartbeat:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat(completed_count)
|
||||
last_heartbeat = now
|
||||
return doc
|
||||
|
||||
tasks = [_download_one(f) for f in files]
|
||||
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
failed = 0
|
||||
for outcome in outcomes:
|
||||
if isinstance(outcome, Exception) or outcome is None:
|
||||
failed += 1
|
||||
else:
|
||||
results.append(outcome)
|
||||
|
||||
return results, failed
|
||||
|
||||
|
||||
async def _download_and_index(
|
||||
onedrive_client: OneDriveClient,
|
||||
session: AsyncSession,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Parallel download then parallel indexing. Returns (batch_indexed, total_failed)."""
|
||||
connector_docs, download_failed = await _download_files_parallel(
|
||||
onedrive_client,
|
||||
files,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
batch_indexed = 0
|
||||
batch_failed = 0
|
||||
if connector_docs:
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
    """Remove a document that was deleted in OneDrive."""
    primary_hash = compute_identifier_hash(
        DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id
    )
    existing = await check_document_by_unique_identifier(session, primary_hash)

    if not existing:
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.document_type == DocumentType.ONEDRIVE_FILE,
                cast(Document.document_metadata["onedrive_file_id"], String) == file_id,
            )
        )
        existing = result.scalar_one_or_none()

    if existing:
        await session.delete(existing)
        logger.info(f"Removed deleted OneDrive file document: {file_id}")
|
||||
|
||||
|
||||
async def _index_selected_files(
|
||||
onedrive_client: OneDriveClient,
|
||||
session: AsyncSession,
|
||||
file_ids: list[tuple[str, str | None]],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(onedrive_client, file_id)
|
||||
if error or not file:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': {error or 'File not found'}")
|
||||
continue
|
||||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
onedrive_client,
|
||||
session,
|
||||
files_to_download,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scan strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _index_full_scan(
|
||||
onedrive_client: OneDriveClient,
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
folder_id: str,
|
||||
folder_name: str,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: object,
|
||||
max_files: int,
|
||||
include_subfolders: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
{
|
||||
"stage": "full_scan",
|
||||
"folder_id": folder_id,
|
||||
"include_subfolders": include_subfolders,
|
||||
},
|
||||
)
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
||||
all_files, error = await get_files_in_folder(
|
||||
onedrive_client,
|
||||
folder_id,
|
||||
include_subfolders=include_subfolders,
|
||||
)
|
||||
if error:
|
||||
err_lower = error.lower()
|
||||
if "401" in error or "authentication expired" in err_lower:
|
||||
raise Exception(
|
||||
f"OneDrive authentication failed. Please re-authenticate. (Error: {error})"
|
||||
)
|
||||
raise Exception(f"Failed to list OneDrive files: {error}")
|
||||
|
||||
for file in all_files[:max_files]:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
onedrive_client,
|
||||
session,
|
||||
files_to_download,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
    onedrive_client: OneDriveClient,
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    folder_id: str | None,
    delta_link: str,
    task_logger: TaskLoggingService,
    log_entry: object,
    max_files: int,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
    enable_summary: bool = True,
) -> tuple[int, int, str | None]:
    """Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link)."""
    await task_logger.log_task_progress(
        log_entry,
        "Starting delta sync",
        {"stage": "delta_sync"},
    )

    changes, new_delta_link, error = await onedrive_client.get_delta(
        folder_id=folder_id, delta_link=delta_link
    )
    if error:
        err_lower = error.lower()
        if "401" in error or "authentication expired" in err_lower:
            raise Exception(
                f"OneDrive authentication failed. Please re-authenticate. (Error: {error})"
            )
        raise Exception(f"Failed to fetch OneDrive changes: {error}")

    if not changes:
        logger.info("No changes detected since last sync")
        return 0, 0, new_delta_link

    logger.info(f"Processing {len(changes)} delta changes")

    renamed_count = 0
    skipped = 0
    files_to_download: list[dict] = []
    files_processed = 0

    for change in changes:
        if files_processed >= max_files:
            break
        files_processed += 1

        if change.get("deleted"):
            fid = change.get("id")
            if fid:
                await _remove_document(session, fid, search_space_id)
            continue

        if "folder" in change:
            continue

        if not change.get("file"):
            continue

        skip, msg = await _should_skip_file(session, change, search_space_id)
        if skip:
            if msg and "renamed" in msg.lower():
                renamed_count += 1
            else:
                skipped += 1
            continue

        files_to_download.append(change)

    batch_indexed, failed = await _download_and_index(
        onedrive_client,
        session,
        files_to_download,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(
        f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
    )
    return indexed, skipped, new_delta_link
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def index_onedrive_files(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index OneDrive files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
{
|
||||
"folders": [{"id": "...", "name": "..."}, ...],
|
||||
"files": [{"id": "...", "name": "..."}, ...],
|
||||
"indexing_options": {"max_files": 500, "include_subfolders": true, "use_delta_sync": true}
|
||||
}
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="onedrive_files_indexing",
|
||||
source="connector_indexing_task",
|
||||
message=f"Starting OneDrive indexing for connector {connector_id}",
|
||||
metadata={"connector_id": connector_id, "user_id": str(user_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
connector = await get_connector_by_id(
|
||||
session, connector_id, SearchSourceConnectorType.ONEDRIVE_CONNECTOR
|
||||
)
|
||||
if not connector:
|
||||
error_msg = f"OneDrive connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
error_msg = "SECRET_KEY not configured but credentials are encrypted"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
error_msg,
|
||||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
|
||||
indexing_options = items_dict.get("indexing_options", {})
|
||||
max_files = indexing_options.get("max_files", 500)
|
||||
include_subfolders = indexing_options.get("include_subfolders", True)
|
||||
use_delta_sync = indexing_options.get("use_delta_sync", True)
|
||||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
|
||||
# Index selected individual files
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
file_tuples = [(f["id"], f.get("name")) for f in selected_files]
|
||||
indexed, skipped, _errors = await _index_selected_files(
|
||||
onedrive_client,
|
||||
session,
|
||||
file_tuples,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
# Index selected folders
|
||||
folders = items_dict.get("folders", [])
|
||||
for folder in folders:
|
||||
folder_id = folder.get("id", "root")
|
||||
folder_name = folder.get("name", "Root")
|
||||
|
||||
folder_delta_links = connector.config.get("folder_delta_links", {})
|
||||
delta_link = folder_delta_links.get(folder_id)
|
||||
can_use_delta = use_delta_sync and delta_link and connector.last_indexed_at
|
||||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for folder {folder_name}")
|
||||
indexed, skipped, new_delta_link = await _index_with_delta_sync(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_id,
|
||||
delta_link,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
if new_delta_link:
|
||||
await session.refresh(connector)
|
||||
if "folder_delta_links" not in connector.config:
|
||||
connector.config["folder_delta_links"] = {}
|
||||
connector.config["folder_delta_links"][folder_id] = new_delta_link
|
||||
flag_modified(connector, "config")
|
||||
|
||||
# Reconciliation full scan
|
||||
ri, rs = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_id,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_indexed += ri
|
||||
total_skipped += rs
|
||||
else:
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_id,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
# Store new delta link for this folder
|
||||
_, new_delta_link, _ = await onedrive_client.get_delta(folder_id=folder_id)
|
||||
if new_delta_link:
|
||||
await session.refresh(connector)
|
||||
if "folder_delta_links" not in connector.config:
|
||||
connector.config["folder_delta_links"] = {}
|
||||
connector.config["folder_delta_links"][folder_id] = new_delta_link
|
||||
flag_modified(connector, "config")
|
||||
|
||||
if total_indexed > 0 or folders:
|
||||
await update_connector_last_indexed(session, connector, True)
|
||||
|
||||
await session.commit()
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed OneDrive indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
)
|
||||
logger.info(
|
||||
f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Database error during OneDrive indexing for connector {connector_id}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to index OneDrive files for connector {connector_id}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index OneDrive files: {e!s}"
|
||||
|
|
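The delta-sync loop in the OneDrive indexer sorts every change entry into one of three outcomes (remove, ignore, or consider for download) before anything is fetched. The following is a minimal standalone sketch of that triage, assuming change entries are plain dicts carrying the same keys the loop reads (`deleted`, `folder`, `file`, `id`). The function name `triage_changes` is hypothetical and is not part of the connector; the real loop additionally consults `_should_skip_file`, which is omitted here.

```python
def triage_changes(
    changes: list[dict], max_files: int
) -> tuple[list[str], list[dict]]:
    """Split delta entries into IDs to remove and file entries to consider.

    As in the indexer loop, the cap counts every entry examined,
    including deletions and folders, not only downloadable files.
    """
    to_remove: list[str] = []
    candidates: list[dict] = []
    for examined, change in enumerate(changes):
        if examined >= max_files:
            break
        if change.get("deleted"):
            # Deleted entries are removed from the index when they carry an ID.
            if change.get("id"):
                to_remove.append(change["id"])
            continue
        # Folders and entries without a non-empty "file" facet are ignored.
        if "folder" in change or not change.get("file"):
            continue
        candidates.append(change)
    return to_remove, candidates
```

One consequence worth noting: because deletions and folders consume the `max_files` budget, a delta batch dominated by such entries may leave fewer file slots than the option name suggests.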
@@ -17,6 +17,7 @@ from sqlalchemy import update

from app.config import config
from app.db import (
    Prompt,
    SearchSpace,
    SearchSpaceMembership,
    SearchSpaceRole,

@@ -25,6 +26,7 @@ from app.db import (
    get_default_roles_config,
    get_user_db,
)
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
from app.utils.refresh_tokens import create_refresh_token

logger = logging.getLogger(__name__)

@@ -188,6 +190,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
                )
                session.add(owner_membership)

                for default in SYSTEM_PROMPT_DEFAULTS:
                    session.add(
                        Prompt(
                            user_id=user.id,
                            default_prompt_slug=default["slug"],
                            name=default["name"],
                            prompt=default["prompt"],
                            mode=default["mode"],
                            version=default["version"],
                        )
                    )

                await session.commit()
                logger.info(
                    f"Created default search space (ID: {default_search_space.id}) for user {user.id}"

@@ -21,6 +21,8 @@ BASE_NAME_FOR_TYPE = {
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
    SearchSourceConnectorType.SLACK_CONNECTOR: "Slack",
    SearchSourceConnectorType.TEAMS_CONNECTOR: "Microsoft Teams",
    SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive",
    SearchSourceConnectorType.DROPBOX_CONNECTOR: "Dropbox",
    SearchSourceConnectorType.NOTION_CONNECTOR: "Notion",
    SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear",
    SearchSourceConnectorType.JIRA_CONNECTOR: "Jira",

@@ -61,6 +63,12 @@ def extract_identifier_from_credentials(
    if connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
        return credentials.get("tenant_name")

    if connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR:
        return credentials.get("user_email")

    if connector_type == SearchSourceConnectorType.DROPBOX_CONNECTOR:
        return credentials.get("user_email")

    if connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
        return credentials.get("workspace_name")

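The two hunks in the connector-naming module pair a base name per connector type with an identifier pulled from the stored credentials (`user_email` for both OneDrive and Dropbox). How the two pieces are combined into a display name is not shown in this diff, so the sketch below is only an illustration under assumptions: the string keys stand in for the `SearchSourceConnectorType` enum members, and both the `display_name` helper and the `"Base - identifier"` format are hypothetical.

```python
# Stand-ins for the enum-keyed tables in the real module (assumption).
BASE_NAMES = {
    "ONEDRIVE_CONNECTOR": "OneDrive",
    "DROPBOX_CONNECTOR": "Dropbox",
}
IDENTIFIER_KEYS = {
    "ONEDRIVE_CONNECTOR": "user_email",
    "DROPBOX_CONNECTOR": "user_email",
}


def display_name(connector_type: str, credentials: dict) -> str:
    """Combine the base name with the account identifier when one is present."""
    base = BASE_NAMES[connector_type]
    identifier = credentials.get(IDENTIFIER_KEYS[connector_type])
    # Fall back to the bare base name if the credentials carry no identifier.
    return f"{base} - {identifier}" if identifier else base
```

Falling back to the base name keeps the connector usable in a listing even when the OAuth response did not include an email address.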
@@ -1,6 +1,6 @@
[project]
name = "surf-new-backend"
-version = "0.0.13"
+version = "0.0.14"
description = "SurfSense Backend"
requires-python = ">=3.12"
dependencies = [

@@ -74,6 +74,7 @@ dependencies = [
    "langgraph>=1.1.3",
    "langchain-community>=0.4.1",
    "deepagents>=0.4.12",
    "stripe>=15.0.0",
]

[dependency-groups]

@@ -0,0 +1,494 @@
from __future__ import annotations

from types import SimpleNamespace
from urllib.parse import parse_qs, urlparse

import asyncpg
import httpx
import pytest
import pytest_asyncio
from httpx import ASGITransport

from app.app import app
from app.routes import stripe_routes
from app.tasks.celery_tasks import stripe_reconciliation_task
from tests.conftest import TEST_DATABASE_URL
from tests.utils.helpers import TEST_EMAIL, TEST_PASSWORD, auth_headers

pytestmark = pytest.mark.integration

_ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")


async def _execute(query: str, *args) -> None:
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        await conn.execute(query, *args)
    finally:
        await conn.close()


async def _fetchrow(query: str, *args):
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        return await conn.fetchrow(query, *args)
    finally:
        await conn.close()


async def _get_user_id(email: str) -> str:
    row = await _fetchrow('SELECT id FROM "user" WHERE email = $1', email)
    assert row is not None, f"User {email!r} not found"
    return str(row["id"])


async def _get_pages_limit(email: str) -> int:
    row = await _fetchrow('SELECT pages_limit FROM "user" WHERE email = $1', email)
    assert row is not None, f"User {email!r} not found"
    return row["pages_limit"]


def _extract_access_token(response: httpx.Response) -> str | None:
    if response.status_code == 200:
        return response.json()["access_token"]

    if response.status_code == 302:
        location = response.headers.get("location", "")
        return parse_qs(urlparse(location).query).get("token", [None])[0]

    return None


async def _authenticate_test_user(client: httpx.AsyncClient) -> str:
    response = await client.post(
        "/auth/jwt/login",
        data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    token = _extract_access_token(response)
    if token:
        return token

    reg_response = await client.post(
        "/auth/register",
        json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
    )
    assert reg_response.status_code == 201, (
        f"Registration failed ({reg_response.status_code}): {reg_response.text}"
    )

    response = await client.post(
        "/auth/jwt/login",
        data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    token = _extract_access_token(response)
    assert token, f"Login failed ({response.status_code}): {response.text}"
    return token


@pytest_asyncio.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
    ) as client:
        return await _authenticate_test_user(client)


@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    return auth_headers(auth_token)


@pytest.fixture(autouse=True)
async def _cleanup_page_purchases():
    await _execute("DELETE FROM page_purchases")
    yield
    await _execute("DELETE FROM page_purchases")


class _FakeCreateStripeClient:
    def __init__(self, checkout_session):
        self.checkout_session = checkout_session
        self.last_params = None
        self.v1 = SimpleNamespace(
            checkout=SimpleNamespace(
                sessions=SimpleNamespace(create=self._create_session)
            )
        )

    def _create_session(self, *, params):
        self.last_params = params
        return self.checkout_session


class _FakeWebhookStripeClient:
    def __init__(self, event):
        self.event = event
        self.last_payload = None
        self.last_signature = None
        self.last_secret = None

    def construct_event(self, payload, signature, secret):
        self.last_payload = payload
        self.last_signature = signature
        self.last_secret = secret
        return self.event


class _FakeReconciliationStripeClient:
    def __init__(self, checkout_session):
        self.checkout_session = checkout_session
        self.requested_ids = []
        self.v1 = SimpleNamespace(
            checkout=SimpleNamespace(
                sessions=SimpleNamespace(retrieve=self._retrieve_session)
            )
        )

    def _retrieve_session(self, checkout_session_id: str):
        self.requested_ids.append(checkout_session_id)
        return self.checkout_session


class TestStripeCheckoutSessionCreation:
    async def test_get_status_reflects_backend_toggle(
        self, client, headers, monkeypatch
    ):
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False)
        disabled_response = await client.get("/api/v1/stripe/status", headers=headers)
        assert disabled_response.status_code == 200, disabled_response.text
        assert disabled_response.json() == {"page_buying_enabled": False}

        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", True)
        enabled_response = await client.get("/api/v1/stripe/status", headers=headers)
        assert enabled_response.status_code == 200, enabled_response.text
        assert enabled_response.json() == {"page_buying_enabled": True}

    async def test_create_checkout_session_records_pending_purchase(
        self,
        client,
        headers,
        search_space_id: int,
        monkeypatch,
    ):
        checkout_session = SimpleNamespace(
            id="cs_test_create_123",
            url="https://checkout.stripe.test/cs_test_create_123",
            payment_intent=None,
            amount_total=None,
            currency=None,
        )
        fake_client = _FakeCreateStripeClient(checkout_session)

        monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: fake_client)
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000")
        monkeypatch.setattr(
            stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000"
        )
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000)

        response = await client.post(
            "/api/v1/stripe/create-checkout-session",
            headers=headers,
            json={"quantity": 2, "search_space_id": search_space_id},
        )

        assert response.status_code == 200, response.text
        assert response.json() == {"checkout_url": checkout_session.url}
        assert fake_client.last_params is not None
        assert fake_client.last_params["mode"] == "payment"
        assert fake_client.last_params["line_items"] == [
            {"price": "price_pages_1000", "quantity": 2}
        ]
        assert (
            fake_client.last_params["success_url"]
            == f"http://localhost:3000/dashboard/{search_space_id}/purchase-success"
        )
        assert (
            fake_client.last_params["cancel_url"]
            == f"http://localhost:3000/dashboard/{search_space_id}/purchase-cancel"
        )

        purchase = await _fetchrow(
            """
            SELECT quantity, pages_granted, status
            FROM page_purchases
            WHERE stripe_checkout_session_id = $1
            """,
            checkout_session.id,
        )
        assert purchase is not None
        assert purchase["quantity"] == 2
        assert purchase["pages_granted"] == 2000
        assert purchase["status"] == "PENDING"

    async def test_create_checkout_session_returns_503_when_buying_disabled(
        self,
        client,
        headers,
        search_space_id: int,
        monkeypatch,
    ):
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False)

        response = await client.post(
            "/api/v1/stripe/create-checkout-session",
            headers=headers,
            json={"quantity": 2, "search_space_id": search_space_id},
        )

        assert response.status_code == 503, response.text
        assert (
            response.json()["detail"] == "Page purchases are temporarily unavailable."
        )

        purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases")
        assert purchase_count is not None
        assert purchase_count["count"] == 0


class TestStripeWebhookFulfillment:
    async def test_webhook_grants_pages_once(
        self,
        client,
        headers,
        search_space_id: int,
        page_limits,
        monkeypatch,
    ):
        await page_limits.set(pages_used=0, pages_limit=100)

        checkout_session = SimpleNamespace(
            id="cs_test_webhook_123",
            url="https://checkout.stripe.test/cs_test_webhook_123",
            payment_intent=None,
            amount_total=None,
            currency=None,
        )
        create_client = _FakeCreateStripeClient(checkout_session)

        monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client)
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000")
        monkeypatch.setattr(
            stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000"
        )
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000)

        create_response = await client.post(
            "/api/v1/stripe/create-checkout-session",
            headers=headers,
            json={"quantity": 3, "search_space_id": search_space_id},
        )
        assert create_response.status_code == 200, create_response.text

        initial_limit = await _get_pages_limit(TEST_EMAIL)
        assert initial_limit == 100

        user_id = await _get_user_id(TEST_EMAIL)
        webhook_checkout_session = SimpleNamespace(
            id=checkout_session.id,
            payment_status="paid",
            payment_intent="pi_test_123",
            amount_total=300,
            currency="usd",
            metadata={
                "user_id": user_id,
                "quantity": "3",
                "pages_per_unit": "1000",
            },
        )
        event = SimpleNamespace(
            type="checkout.session.completed",
            data=SimpleNamespace(object=webhook_checkout_session),
        )
        webhook_client = _FakeWebhookStripeClient(event)

        monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: webhook_client)
        monkeypatch.setattr(stripe_routes.config, "STRIPE_WEBHOOK_SECRET", "whsec_test")

        first_response = await client.post(
            "/api/v1/stripe/webhook",
            headers={"Stripe-Signature": "sig_test"},
            content=b"{}",
        )
        assert first_response.status_code == 200, first_response.text

        updated_limit = await _get_pages_limit(TEST_EMAIL)
        assert updated_limit == 3100

        purchase = await _fetchrow(
            """
            SELECT status, amount_total, currency, stripe_payment_intent_id
            FROM page_purchases
            WHERE stripe_checkout_session_id = $1
            """,
            checkout_session.id,
        )
        assert purchase is not None
        assert purchase["status"] == "COMPLETED"
        assert purchase["amount_total"] == 300
        assert purchase["currency"] == "usd"
        assert purchase["stripe_payment_intent_id"] == "pi_test_123"

        second_response = await client.post(
            "/api/v1/stripe/webhook",
            headers={"Stripe-Signature": "sig_test"},
            content=b"{}",
        )
        assert second_response.status_code == 200, second_response.text

        assert await _get_pages_limit(TEST_EMAIL) == 3100


class TestStripeReconciliation:
    async def test_reconciliation_fulfills_paid_pending_purchase(
        self,
        client,
        headers,
        search_space_id: int,
        page_limits,
        monkeypatch,
    ):
        await page_limits.set(pages_used=220, pages_limit=150)

        checkout_session = SimpleNamespace(
            id="cs_test_reconcile_paid_123",
            url="https://checkout.stripe.test/cs_test_reconcile_paid_123",
            payment_intent=None,
            amount_total=None,
            currency=None,
        )
        create_client = _FakeCreateStripeClient(checkout_session)

        monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client)
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000")
        monkeypatch.setattr(
            stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000"
        )
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000)

        create_response = await client.post(
            "/api/v1/stripe/create-checkout-session",
            headers=headers,
            json={"quantity": 3, "search_space_id": search_space_id},
        )
        assert create_response.status_code == 200, create_response.text
        assert await _get_pages_limit(TEST_EMAIL) == 150

        reconciled_session = SimpleNamespace(
            id=checkout_session.id,
            status="complete",
            payment_status="paid",
            payment_intent="pi_test_reconcile_123",
            amount_total=300,
            currency="usd",
            metadata={},
        )
        reconcile_client = _FakeReconciliationStripeClient(reconciled_session)

        monkeypatch.setattr(
            stripe_reconciliation_task, "get_stripe_client", lambda: reconcile_client
        )
        monkeypatch.setattr(
            stripe_reconciliation_task.config,
            "STRIPE_RECONCILIATION_LOOKBACK_MINUTES",
            0,
        )
        monkeypatch.setattr(
            stripe_reconciliation_task.config,
            "STRIPE_RECONCILIATION_BATCH_SIZE",
            20,
        )

        await stripe_reconciliation_task._reconcile_pending_page_purchases()

        assert reconcile_client.requested_ids == [checkout_session.id]
        assert await _get_pages_limit(TEST_EMAIL) == 3220

        purchase = await _fetchrow(
            """
            SELECT status, amount_total, currency, stripe_payment_intent_id
            FROM page_purchases
            WHERE stripe_checkout_session_id = $1
            """,
            checkout_session.id,
        )
        assert purchase is not None
        assert purchase["status"] == "COMPLETED"
        assert purchase["amount_total"] == 300
        assert purchase["currency"] == "usd"
        assert purchase["stripe_payment_intent_id"] == "pi_test_reconcile_123"

    async def test_reconciliation_marks_expired_pending_purchase_failed(
        self,
        client,
        headers,
        search_space_id: int,
        page_limits,
        monkeypatch,
    ):
        await page_limits.set(pages_used=0, pages_limit=500)

        checkout_session = SimpleNamespace(
            id="cs_test_reconcile_expired_123",
            url="https://checkout.stripe.test/cs_test_reconcile_expired_123",
            payment_intent=None,
            amount_total=None,
            currency=None,
        )
        create_client = _FakeCreateStripeClient(checkout_session)

        monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client)
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000")
        monkeypatch.setattr(
            stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000"
        )
        monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000)

        create_response = await client.post(
            "/api/v1/stripe/create-checkout-session",
            headers=headers,
            json={"quantity": 1, "search_space_id": search_space_id},
        )
        assert create_response.status_code == 200, create_response.text

        expired_session = SimpleNamespace(
            id=checkout_session.id,
            status="expired",
            payment_status="unpaid",
            payment_intent=None,
            amount_total=100,
            currency="usd",
            metadata={},
        )
        reconcile_client = _FakeReconciliationStripeClient(expired_session)

        monkeypatch.setattr(
            stripe_reconciliation_task, "get_stripe_client", lambda: reconcile_client
        )
        monkeypatch.setattr(
            stripe_reconciliation_task.config,
            "STRIPE_RECONCILIATION_LOOKBACK_MINUTES",
            0,
        )
        monkeypatch.setattr(
            stripe_reconciliation_task.config,
            "STRIPE_RECONCILIATION_BATCH_SIZE",
            20,
        )

        await stripe_reconciliation_task._reconcile_pending_page_purchases()

        assert await _get_pages_limit(TEST_EMAIL) == 500

        purchase = await _fetchrow(
            """
            SELECT status
            FROM page_purchases
            WHERE stripe_checkout_session_id = $1
            """,
            checkout_session.id,
        )
        assert purchase is not None
        assert purchase["status"] == "FAILED"

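The Stripe tests pin down two behaviours of fulfillment: a replayed `checkout.session.completed` event must not grant pages twice, and the numbers in the reconciliation test (`pages_used=220`, `pages_limit=150`, result `3220`) are consistent with the grant being added to the larger of the two values. The sketch below illustrates both with in-memory objects. It is not the route or task implementation, which this diff does not show: `Account`, `Purchase`, and `fulfill` are hypothetical names, and the `max(...)` rule is an inference from the test expectations.

```python
from dataclasses import dataclass


@dataclass
class Account:
    pages_used: int
    pages_limit: int


@dataclass
class Purchase:
    pages_granted: int
    status: str = "PENDING"


def fulfill(purchase: Purchase, account: Account) -> bool:
    """Grant pages once; return True only when this call performed the grant."""
    if purchase.status != "PENDING":
        # Already COMPLETED or FAILED: a replayed event is a no-op.
        return False
    # Assumption inferred from the tests: top up from whichever value is larger.
    base = max(account.pages_used, account.pages_limit)
    account.pages_limit = base + purchase.pages_granted
    purchase.status = "COMPLETED"
    return True
```

Keying idempotency on the purchase's own status, rather than on the event, means the webhook path and the reconciliation task can both call the same routine without coordinating with each other.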
@@ -0,0 +1,106 @@
"""Integration tests: Dropbox ConnectorDocuments flow through the pipeline."""

import pytest
from sqlalchemy import select

from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension

pytestmark = pytest.mark.integration


def _dropbox_doc(
    *, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
    return ConnectorDocument(
        title=f"File {unique_id}.docx",
        source_markdown=f"## Document\n\nContent from {unique_id}",
        unique_id=unique_id,
        document_type=DocumentType.DROPBOX_FILE,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=True,
        fallback_summary=f"File: {unique_id}.docx",
        metadata={
            "dropbox_file_id": unique_id,
            "dropbox_file_name": f"{unique_id}.docx",
            "document_type": "Dropbox File",
        },
    )


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_dropbox_pipeline_creates_ready_document(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A Dropbox ConnectorDocument flows through prepare + index to a READY document."""
    space_id = db_search_space.id
    doc = _dropbox_doc(
        unique_id="db-file-abc",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=str(db_user.id),
    )

    service = IndexingPipelineService(session=db_session)
    prepared = await service.prepare_for_indexing([doc])
    assert len(prepared) == 1

    await service.index(prepared[0], doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == space_id)
    )
    row = result.scalars().first()

    assert row is not None
    assert row.document_type == DocumentType.DROPBOX_FILE
    assert DocumentStatus.is_state(row.status, DocumentStatus.READY)


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_dropbox_duplicate_content_skipped(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """Re-indexing a Dropbox doc with the same content is skipped (content hash match)."""
    space_id = db_search_space.id
    user_id = str(db_user.id)

    doc = _dropbox_doc(
        unique_id="db-dup-file",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc])
    assert len(prepared) == 1
    await service.index(prepared[0], doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == space_id)
    )
    first_doc = result.scalars().first()
    assert first_doc is not None
    doc2 = _dropbox_doc(
        unique_id="db-dup-file",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    prepared2 = await service.prepare_for_indexing([doc2])
    assert len(prepared2) == 0 or (
        len(prepared2) == 1 and prepared2[0].existing_document is not None
    )

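The duplicate-content tests rely on the pipeline recognising that a document with the same `unique_id` and unchanged content needs no re-indexing ("content hash match" in the docstring). The hashing scheme itself is not part of this diff, so the following is only a rough sketch of the idea under assumptions: SHA-256 over the search space ID plus the markdown, and an in-memory dict standing in for the documents table. `content_hash` and `needs_indexing` are hypothetical names.

```python
import hashlib


def content_hash(markdown: str, search_space_id: int) -> str:
    """Hash the content together with its search space (assumed scheme)."""
    payload = f"{search_space_id}:{markdown}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def needs_indexing(
    seen: dict[str, str], unique_id: str, markdown: str, search_space_id: int
) -> bool:
    """Return False when the stored hash for this unique_id is unchanged."""
    digest = content_hash(markdown, search_space_id)
    if seen.get(unique_id) == digest:
        return False
    # New document or changed content: remember the hash and index it.
    seen[unique_id] = digest
    return True
```

Note that the real tests accept either outcome for the second pass: an empty prepared list, or one entry whose `existing_document` is set, so the actual service may return a reference to the existing row instead of dropping the duplicate.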
@@ -0,0 +1,106 @@
"""Integration tests: OneDrive ConnectorDocuments flow through the pipeline."""

import pytest
from sqlalchemy import select

from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension

pytestmark = pytest.mark.integration


def _onedrive_doc(
    *, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
    return ConnectorDocument(
        title=f"File {unique_id}.docx",
        source_markdown=f"## Document\n\nContent from {unique_id}",
        unique_id=unique_id,
        document_type=DocumentType.ONEDRIVE_FILE,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=True,
        fallback_summary=f"File: {unique_id}.docx",
        metadata={
            "onedrive_file_id": unique_id,
            "onedrive_file_name": f"{unique_id}.docx",
            "document_type": "OneDrive File",
        },
    )


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_onedrive_pipeline_creates_ready_document(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """A OneDrive ConnectorDocument flows through prepare + index to a READY document."""
    space_id = db_search_space.id
    doc = _onedrive_doc(
        unique_id="od-file-abc",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=str(db_user.id),
    )

    service = IndexingPipelineService(session=db_session)
    prepared = await service.prepare_for_indexing([doc])
    assert len(prepared) == 1

    await service.index(prepared[0], doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == space_id)
    )
    row = result.scalars().first()

    assert row is not None
    assert row.document_type == DocumentType.ONEDRIVE_FILE
    assert DocumentStatus.is_state(row.status, DocumentStatus.READY)


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_onedrive_duplicate_content_skipped(
    db_session, db_search_space, db_connector, db_user, mocker
):
    """Re-indexing a OneDrive doc with the same content is skipped (content hash match)."""
    space_id = db_search_space.id
    user_id = str(db_user.id)

    doc = _onedrive_doc(
        unique_id="od-dup-file",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc])
    assert len(prepared) == 1
    await service.index(prepared[0], doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == space_id)
    )
    first_doc = result.scalars().first()
    assert first_doc is not None
    doc2 = _onedrive_doc(
        unique_id="od-dup-file",
        search_space_id=space_id,
        connector_id=db_connector.id,
        user_id=user_id,
    )

    prepared2 = await service.prepare_for_indexing([doc2])
    assert len(prepared2) == 0 or (
        len(prepared2) == 1 and prepared2[0].existing_document is not None
|
||||
)
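The duplicate test above relies on the pipeline recognising unchanged content by hash. The pipeline's own hashing code is not part of this hunk, so the following is only a minimal sketch of the idea; `content_hash`, `needs_indexing`, and the in-memory `seen` map are hypothetical names, not SurfSense APIs.

```python
import hashlib


def content_hash(markdown: str, search_space_id: int) -> str:
    """Hash the content together with its search space (assumed scoping)."""
    payload = f"{search_space_id}:{markdown}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def needs_indexing(
    seen: dict[str, str], unique_id: str, markdown: str, search_space_id: int
) -> bool:
    """Return False when the same unique_id was already stored with identical content."""
    digest = content_hash(markdown, search_space_id)
    if seen.get(unique_id) == digest:
        return False  # same content hash: skip re-indexing
    seen[unique_id] = digest  # new or changed content: record it and index
    return True
```

Under this sketch, a second pass over `od-dup-file` with the same markdown would be skipped, while any change to the markdown would trigger re-indexing.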
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
-from datetime import UTC, datetime
|
||||
+from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -22,6 +22,7 @@ def _make_document(
|
|||
content: str,
|
||||
search_space_id: int,
|
||||
created_by_id: str,
|
||||
updated_at: datetime | None = None,
|
||||
) -> Document:
|
||||
uid = uuid.uuid4().hex[:12]
|
||||
return Document(
|
||||
|
|
@ -34,7 +35,7 @@ def _make_document(
|
|||
search_space_id=search_space_id,
|
||||
created_by_id=created_by_id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
-updated_at=datetime.now(UTC),
|
||||
+updated_at=updated_at or datetime.now(UTC),
|
||||
status={"state": "ready"},
|
||||
)
|
||||
|
||||
|
|
@ -104,3 +105,54 @@ async def seed_large_doc(
|
|||
"search_space": db_search_space,
|
||||
"user": db_user,
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seed_date_filtered_docs(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
):
|
||||
"""Insert matching docs with different timestamps for date-filter tests."""
|
||||
user_id = str(db_user.id)
|
||||
space_id = db_search_space.id
|
||||
now = datetime.now(UTC)
|
||||
|
||||
recent_doc = _make_document(
|
||||
title="Recent OCV Notes",
|
||||
document_type=DocumentType.FILE,
|
||||
content="ocv meeting decisions and action items",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
updated_at=now,
|
||||
)
|
||||
old_doc = _make_document(
|
||||
title="Old OCV Notes",
|
||||
document_type=DocumentType.FILE,
|
||||
content="ocv meeting decisions and action items",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
updated_at=now - timedelta(days=730),
|
||||
)
|
||||
|
||||
db_session.add_all([recent_doc, old_doc])
|
||||
await db_session.flush()
|
||||
|
||||
db_session.add_all(
|
||||
[
|
||||
_make_chunk(
|
||||
content="ocv meeting decisions and action items recent",
|
||||
document_id=recent_doc.id,
|
||||
),
|
||||
_make_chunk(
|
||||
content="ocv meeting decisions and action items old",
|
||||
document_id=old_doc.id,
|
||||
),
|
||||
]
|
||||
)
|
||||
await db_session.flush()
|
||||
|
||||
return {
|
||||
"recent_doc": recent_doc,
|
||||
"old_doc": old_doc,
|
||||
"search_space": db_search_space,
|
||||
"user": db_user,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
"""Integration smoke tests for KB search query/date scoping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_search import search_knowledge_base
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_search_knowledge_base_applies_date_filters(
|
||||
db_session,
|
||||
seed_date_filtered_docs,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Date filters should remove older matching documents from scoped KB results."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_shielded_async_session():
|
||||
yield db_session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.shielded_async_session",
|
||||
fake_shielded_async_session,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.embed_texts",
|
||||
lambda texts: [np.array(DUMMY_EMBEDDING) for _ in texts],
|
||||
)
|
||||
|
||||
space_id = seed_date_filtered_docs["search_space"].id
|
||||
recent_cutoff = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
unfiltered_results = await search_knowledge_base(
|
||||
query="ocv meeting decisions",
|
||||
search_space_id=space_id,
|
||||
available_document_types=["FILE"],
|
||||
top_k=10,
|
||||
)
|
||||
filtered_results = await search_knowledge_base(
|
||||
query="ocv meeting decisions",
|
||||
search_space_id=space_id,
|
||||
available_document_types=["FILE"],
|
||||
top_k=10,
|
||||
start_date=recent_cutoff,
|
||||
end_date=datetime.now(UTC),
|
||||
)
|
||||
|
||||
unfiltered_ids = {result["document"]["id"] for result in unfiltered_results}
|
||||
filtered_ids = {result["document"]["id"] for result in filtered_results}
|
||||
|
||||
assert seed_date_filtered_docs["recent_doc"].id in unfiltered_ids
|
||||
assert seed_date_filtered_docs["old_doc"].id in unfiltered_ids
|
||||
assert seed_date_filtered_docs["recent_doc"].id in filtered_ids
|
||||
assert seed_date_filtered_docs["old_doc"].id not in filtered_ids
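The smoke test above expects a document last updated two years ago to drop out once `start_date` and `end_date` are supplied. `search_knowledge_base` itself is not shown in this diff, so the snippet below is only an illustrative sketch of an inclusive date window applied to `updated_at`; `within_range` and `filter_by_date` are hypothetical helpers operating on plain dicts rather than ORM rows.

```python
from datetime import datetime
from typing import Optional


def within_range(
    updated_at: datetime,
    start_date: Optional[datetime],
    end_date: Optional[datetime],
) -> bool:
    """Inclusive window check; a bound of None leaves that side open."""
    if start_date is not None and updated_at < start_date:
        return False
    if end_date is not None and updated_at > end_date:
        return False
    return True


def filter_by_date(
    docs: list[dict],
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> list[dict]:
    """Keep only documents whose 'updated_at' falls inside the window."""
    return [d for d in docs if within_range(d["updated_at"], start_date, end_date)]
```

With both bounds set to `None` the filter is a no-op, which matches the unfiltered call in the test returning both documents.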
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
"""Tests for parallel download + indexing in the Dropbox indexer."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.dropbox_indexer import (
|
||||
_download_files_parallel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_file_dict(file_id: str, name: str) -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
".tag": "file",
|
||||
"path_lower": f"/{name}",
|
||||
"server_modified": "2026-01-01T00:00:00Z",
|
||||
"content_hash": f"hash_{file_id}",
|
||||
}
|
||||
|
||||
|
||||
def _mock_extract_ok(file_id: str, file_name: str):
|
||||
return (
|
||||
f"# Content of {file_name}",
|
||||
{"dropbox_file_id": file_id, "dropbox_file_name": file_name},
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dropbox_client():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract(monkeypatch):
|
||||
def _patch(side_effect=None, return_value=None):
|
||||
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.dropbox_indexer.download_and_extract_content",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
# Slice 1: Tracer bullet
|
||||
async def test_single_file_returns_one_connector_document(
|
||||
mock_dropbox_client,
|
||||
patch_extract,
|
||||
):
|
||||
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
[_make_file_dict("f1", "test.txt")],
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 0
|
||||
assert docs[0].title == "test.txt"
|
||||
assert docs[0].unique_id == "f1"
|
||||
assert docs[0].document_type == DocumentType.DROPBOX_FILE
|
||||
|
||||
|
||||
# Slice 2: Multiple files all produce documents
|
||||
async def test_multiple_files_all_produce_documents(
|
||||
mock_dropbox_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
|
||||
|
||||
|
||||
# Slice 3: Error isolation
|
||||
async def test_one_download_exception_does_not_block_others(
|
||||
mock_dropbox_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "file0.txt"),
|
||||
RuntimeError("network timeout"),
|
||||
_mock_extract_ok("f2", "file2.txt"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 2
|
||||
assert failed == 1
|
||||
assert {d.unique_id for d in docs} == {"f0", "f2"}
|
||||
|
||||
|
||||
# Slice 4: ETL error counts as download failure
|
||||
async def test_etl_error_counts_as_download_failure(
|
||||
mock_dropbox_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "good.txt"),
|
||||
(None, {}, "ETL failed"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 1
|
||||
|
||||
|
||||
# Slice 5: Semaphore bound
|
||||
async def test_concurrency_bounded_by_semaphore(
|
||||
mock_dropbox_client,
|
||||
monkeypatch,
|
||||
):
|
||||
lock = asyncio.Lock()
|
||||
active = 0
|
||||
peak = 0
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
nonlocal active, peak
|
||||
async with lock:
|
||||
active += 1
|
||||
peak = max(peak, active)
|
||||
await asyncio.sleep(0.05)
|
||||
async with lock:
|
||||
active -= 1
|
||||
return _mock_extract_ok(file["id"], file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.dropbox_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
max_concurrency=2,
|
||||
)
|
||||
|
||||
assert len(docs) == 6
|
||||
assert failed == 0
|
||||
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
|
||||
|
||||
|
||||
# Slice 6: Heartbeat fires
|
||||
async def test_heartbeat_fires_during_parallel_downloads(
|
||||
mock_dropbox_client,
|
||||
monkeypatch,
|
||||
):
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
await asyncio.sleep(0.05)
|
||||
return _mock_extract_ok(file["id"], file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.dropbox_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
heartbeat_calls: list[int] = []
|
||||
|
||||
async def _on_heartbeat(count: int):
|
||||
heartbeat_calls.append(count)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_dropbox_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
on_heartbeat=_on_heartbeat,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
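Slices 3 to 5 above pin down three behaviours: one failing download must not block the others, an ETL error tuple counts as a failure, and concurrency stays under a semaphore bound. The real `_download_files_parallel` is not included in this hunk; the following is a rough sketch of that pattern under stated assumptions. `download_all`, the `extract` callable, and the dict-shaped documents are hypothetical stand-ins.

```python
import asyncio
from typing import Awaitable, Callable, Optional


async def download_all(
    files: list[dict],
    extract: Callable[[dict], Awaitable[tuple]],
    *,
    max_concurrency: int = 3,
) -> tuple[list[dict], int]:
    """Run `extract` over `files` with bounded concurrency.

    `extract` is assumed to return (markdown, metadata, error), mirroring the
    tuples used by the test doubles above. Returns (documents, failed_count).
    """
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _one(file: dict) -> Optional[dict]:
        async with semaphore:  # at most `max_concurrency` extractions in flight
            try:
                markdown, metadata, error = await extract(file)
            except Exception:
                return None  # a raised exception is isolated to this file
        if error or markdown is None:
            return None  # an ETL error tuple also counts as a failure
        return {
            "unique_id": file["id"],
            "title": file["name"],
            "source_markdown": markdown,
            "metadata": metadata,
        }

    results = await asyncio.gather(*(_one(f) for f in files))
    docs = [r for r in results if r is not None]
    return docs, len(results) - len(docs)
```

Catching the exception inside each task, rather than passing `return_exceptions=True` to `asyncio.gather`, keeps the failure count and the document list in a single pass.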
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
"""Tests for parallel download + indexing in the OneDrive indexer."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.onedrive_indexer import (
|
||||
_download_files_parallel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"file": {"mimeType": mime},
|
||||
"lastModifiedDateTime": "2026-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _mock_extract_ok(file_id: str, file_name: str):
|
||||
return (
|
||||
f"# Content of {file_name}",
|
||||
{"onedrive_file_id": file_id, "onedrive_file_name": file_name},
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_onedrive_client():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract(monkeypatch):
|
||||
def _patch(side_effect=None, return_value=None):
|
||||
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.onedrive_indexer.download_and_extract_content",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
# Slice 1: Tracer bullet
|
||||
async def test_single_file_returns_one_connector_document(
|
||||
mock_onedrive_client,
|
||||
patch_extract,
|
||||
):
|
||||
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
[_make_file_dict("f1", "test.txt")],
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 0
|
||||
assert docs[0].title == "test.txt"
|
||||
assert docs[0].unique_id == "f1"
|
||||
assert docs[0].document_type == DocumentType.ONEDRIVE_FILE
|
||||
|
||||
|
||||
# Slice 2: Multiple files all produce documents
|
||||
async def test_multiple_files_all_produce_documents(
|
||||
mock_onedrive_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
|
||||
|
||||
|
||||
# Slice 3: Error isolation
|
||||
async def test_one_download_exception_does_not_block_others(
|
||||
mock_onedrive_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "file0.txt"),
|
||||
RuntimeError("network timeout"),
|
||||
_mock_extract_ok("f2", "file2.txt"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 2
|
||||
assert failed == 1
|
||||
assert {d.unique_id for d in docs} == {"f0", "f2"}
|
||||
|
||||
|
||||
# Slice 4: ETL error counts as download failure
|
||||
async def test_etl_error_counts_as_download_failure(
|
||||
mock_onedrive_client,
|
||||
patch_extract,
|
||||
):
|
||||
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "good.txt"),
|
||||
(None, {}, "ETL failed"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 1
|
||||
|
||||
|
||||
# Slice 5: Semaphore bound
|
||||
async def test_concurrency_bounded_by_semaphore(
|
||||
mock_onedrive_client,
|
||||
monkeypatch,
|
||||
):
|
||||
lock = asyncio.Lock()
|
||||
active = 0
|
||||
peak = 0
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
nonlocal active, peak
|
||||
async with lock:
|
||||
active += 1
|
||||
peak = max(peak, active)
|
||||
await asyncio.sleep(0.05)
|
||||
async with lock:
|
||||
active -= 1
|
||||
return _mock_extract_ok(file["id"], file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.onedrive_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
max_concurrency=2,
|
||||
)
|
||||
|
||||
assert len(docs) == 6
|
||||
assert failed == 0
|
||||
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
|
||||
|
||||
|
||||
# Slice 6: Heartbeat fires
|
||||
async def test_heartbeat_fires_during_parallel_downloads(
|
||||
mock_onedrive_client,
|
||||
monkeypatch,
|
||||
):
|
||||
import app.tasks.connector_indexers.onedrive_indexer as _mod
|
||||
|
||||
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
await asyncio.sleep(0.05)
|
||||
return _mock_extract_ok(file["id"], file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.onedrive_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
heartbeat_calls: list[int] = []
|
||||
|
||||
async def _on_heartbeat(count: int):
|
||||
heartbeat_calls.append(count)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_onedrive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
on_heartbeat=_on_heartbeat,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
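Slice 6 sets `HEARTBEAT_INTERVAL_SECONDS` to 0 and expects `on_heartbeat` to be called with a count at least once. How the indexer schedules heartbeats is not visible here, so this is one possible sketch, assuming the callback receives the number of completed downloads; `run_with_heartbeat` and `interval_seconds` are hypothetical names.

```python
import asyncio
import time
from typing import Awaitable, Callable


async def run_with_heartbeat(
    tasks: list[Awaitable[object]],
    on_heartbeat: Callable[[int], Awaitable[None]],
    *,
    interval_seconds: float = 30.0,
) -> list[object]:
    """Await `tasks` as they finish, reporting progress at most once per interval."""
    results: list[object] = []
    last_beat = time.monotonic()
    for finished in asyncio.as_completed(tasks):
        results.append(await finished)
        now = time.monotonic()
        if now - last_beat >= interval_seconds:
            await on_heartbeat(len(results))  # report how many have completed
            last_beat = now
    return results
```

With an interval of 0, as in the test, every completion triggers a heartbeat; with a larger interval, short batches may finish without one.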
|
||||
|
|
@ -1,12 +1,16 @@
|
|||
"""Unit tests for knowledge_search middleware helpers.
|
||||
"""Unit tests for knowledge_search middleware helpers."""
|
||||
|
||||
These test pure functions that don't require a database.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
_build_document_xml,
|
||||
_normalize_optional_date_range,
|
||||
_parse_kb_search_plan_response,
|
||||
_render_recent_conversation,
|
||||
_resolve_search_types,
|
||||
)
|
||||
|
||||
|
|
@ -131,3 +135,234 @@ class TestBuildDocumentXml:
|
|||
line for line in lines if "<![CDATA[" in line and "<chunk" in line
|
||||
]
|
||||
assert len(chunk_lines) == 3
|
||||
|
||||
|
||||
# ── planner parsing / date normalization ───────────────────────────────
|
||||
|
||||
|
||||
class TestPlannerHelpers:
|
||||
def test_parse_kb_search_plan_response_accepts_plain_json(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "ocv meeting decisions summary",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-31",
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.optimized_query == "ocv meeting decisions summary"
|
||||
assert plan.start_date == "2026-03-01"
|
||||
assert plan.end_date == "2026-03-31"
|
||||
|
||||
def test_parse_kb_search_plan_response_accepts_fenced_json(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
"""```json
|
||||
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
|
||||
```"""
|
||||
)
|
||||
assert plan.optimized_query == "deel founders guide"
|
||||
assert plan.start_date is None
|
||||
assert plan.end_date is None
|
||||
|
||||
def test_normalize_optional_date_range_returns_none_when_absent(self):
|
||||
start_date, end_date = _normalize_optional_date_range(None, None)
|
||||
assert start_date is None
|
||||
assert end_date is None
|
||||
|
||||
def test_normalize_optional_date_range_resolves_single_bound(self):
|
||||
start_date, end_date = _normalize_optional_date_range("2026-03-01", None)
|
||||
assert start_date is not None
|
||||
assert end_date is not None
|
||||
assert start_date.date().isoformat() == "2026-03-01"
|
||||
assert end_date >= start_date
|
||||
|
||||
|
||||
class FakeLLM:
|
||||
def __init__(self, response_text: str):
|
||||
self.response_text = response_text
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def ainvoke(self, messages, config=None):
|
||||
self.calls.append({"messages": messages, "config": config})
|
||||
return AIMessage(content=self.response_text)
|
||||
|
||||
|
||||
class FakeBudgetLLM:
|
||||
def __init__(self, *, max_input_tokens: int):
|
||||
self._max_input_tokens_value = max_input_tokens
|
||||
|
||||
def _get_max_input_tokens(self) -> int:
|
||||
return self._max_input_tokens_value
|
||||
|
||||
def _count_tokens(self, messages) -> int:
|
||||
# Deterministic, simple proxy for tests: count characters as tokens.
|
||||
return sum(len(msg.get("content", "")) for msg in messages)
|
||||
|
||||
|
||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||
messages = [
|
||||
HumanMessage(content="old user context " * 40),
|
||||
AIMessage(content="old assistant answer " * 35),
|
||||
HumanMessage(content="recent user context " * 20),
|
||||
AIMessage(content="recent assistant answer " * 18),
|
||||
HumanMessage(content="latest question"),
|
||||
]
|
||||
|
||||
rendered = _render_recent_conversation(
|
||||
messages,
|
||||
llm=FakeBudgetLLM(max_input_tokens=900),
|
||||
user_text="latest question",
|
||||
)
|
||||
|
||||
assert "recent user context" in rendered
|
||||
assert "recent assistant answer" in rendered
|
||||
assert "latest question" not in rendered
|
||||
assert rendered.index("recent user context") < rendered.index(
|
||||
"recent assistant answer"
|
||||
)
|
||||
|
||||
def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
|
||||
messages = [
|
||||
HumanMessage(content="message one"),
|
||||
AIMessage(content="message two"),
|
||||
HumanMessage(content="latest question"),
|
||||
]
|
||||
|
||||
rendered = _render_recent_conversation(
|
||||
messages,
|
||||
llm=None,
|
||||
user_text="latest question",
|
||||
)
|
||||
|
||||
assert "user: message one" in rendered
|
||||
assert "assistant: message two" in rendered
|
||||
assert "latest question" not in rendered
|
||||
|
||||
async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "ocv meeting decisions action items",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-31",
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37)
|
||||
|
||||
result = await middleware.abefore_agent(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(content="what happened in our OCV meeting last month?")
|
||||
]
|
||||
},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert captured["query"] == "ocv meeting decisions action items"
|
||||
assert captured["start_date"] is not None
|
||||
assert captured["end_date"] is not None
|
||||
assert captured["start_date"].date().isoformat() == "2026-03-01"
|
||||
assert captured["end_date"].date().isoformat() == "2026-03-31"
|
||||
assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}
|
||||
|
||||
async def test_middleware_falls_back_when_planner_returns_invalid_json(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
middleware = KnowledgeBaseSearchMiddleware(
|
||||
llm=FakeLLM("not json"),
|
||||
search_space_id=37,
|
||||
)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert captured["query"] == "summarize founders guide by deel"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
||||
async def test_middleware_passes_none_dates_when_planner_returns_nulls(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
middleware = KnowledgeBaseSearchMiddleware(
|
||||
llm=FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "deel founders guide summary",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
}
|
||||
)
|
||||
),
|
||||
search_space_id=37,
|
||||
)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="summarize founders guide by deel")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert captured["query"] == "deel founders guide summary"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
|
|
|||
17
surfsense_backend/uv.lock
generated
|
|
@ -7918,9 +7918,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/a9/d3/c238124fbf2dbe5eda203f0a1c4cd6c210e27993ed9780c4c1bf2ab0efbe/static_ffmpeg-3.0-py3-none-any.whl", hash = "sha256:79d9067264cefbb05e6b847be7d6cb7410b7b25adce40178a787f0137567c89f", size = 7927 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stripe"
|
||||
version = "15.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/36/5a/0cdea4b7911b8012936c765544109da27c0728f6911ec7aefe9d59e7a4f9/stripe-15.0.0.tar.gz", hash = "sha256:0717cd9ba8e8193cef8b1c488ce27836754df496ab6fb75864096e0cdf15e52d", size = 1486873 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/4a/4431c998c451cf07f8b4fed98f425b4aaf3d59cc4fb1e6f54d7713606688/stripe-15.0.0-py3-none-any.whl", hash = "sha256:434ec5267a7402a30b76786d159c18d0e138f89195969d6c56bea2e08d353be0", size = 2125454 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "surf-new-backend"
|
||||
-version = "0.0.13"
|
||||
+version = "0.0.14"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "alembic" },
|
||||
|
|
@ -7984,6 +7997,7 @@ dependencies = [
|
|||
{ name = "sse-starlette" },
|
||||
{ name = "starlette" },
|
||||
{ name = "static-ffmpeg" },
|
||||
{ name = "stripe" },
|
||||
{ name = "tavily-python" },
|
||||
{ name = "tornado" },
|
||||
{ name = "trafilatura" },
|
||||
|
|
@ -8067,6 +8081,7 @@ requires-dist = [
|
|||
{ name = "sse-starlette", specifier = ">=3.1.1,<3.1.2" },
|
||||
{ name = "starlette", specifier = ">=0.40.0,<0.51.0" },
|
||||
{ name = "static-ffmpeg", specifier = ">=2.13" },
|
||||
{ name = "stripe", specifier = ">=15.0.0" },
|
||||
{ name = "tavily-python", specifier = ">=0.3.2" },
|
||||
{ name = "tornado", specifier = ">=6.5.5" },
|
||||
{ name = "trafilatura", specifier = ">=2.0.0" },
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"name": "surfsense_browser_extension",
|
||||
"displayName": "Surfsense Browser Extension",
|
||||
-"version": "0.0.13",
|
||||
+"version": "0.0.14",
|
||||
"description": "Extension to collect Browsing History for SurfSense.",
|
||||
"author": "https://github.com/MODSetter",
|
||||
"engines": {
|
||||
|
|
|
|||
|
|
@ -3,4 +3,7 @@ export const IPC_CHANNELS = {
|
|||
GET_APP_VERSION: 'get-app-version',
|
||||
DEEP_LINK: 'deep-link',
|
||||
QUICK_ASK_TEXT: 'quick-ask-text',
|
||||
SET_QUICK_ASK_MODE: 'set-quick-ask-mode',
|
||||
GET_QUICK_ASK_MODE: 'get-quick-ask-mode',
|
||||
REPLACE_TEXT: 'replace-text',
|
||||
} as const;
|
||||
|
|
|
|||
55
surfsense_desktop/src/modules/platform.ts
Normal file
|
|
@ -0,0 +1,55 @@
import { execSync } from 'child_process';
import { systemPreferences } from 'electron';

export function getFrontmostApp(): string {
  try {
    if (process.platform === 'darwin') {
      return execSync(
        'osascript -e \'tell application "System Events" to get name of first application process whose frontmost is true\''
      ).toString().trim();
    }
    if (process.platform === 'win32') {
      return execSync(
        'powershell -command "Add-Type \'using System; using System.Runtime.InteropServices; public class W { [DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow(); }\'; (Get-Process | Where-Object { $_.MainWindowHandle -eq [W]::GetForegroundWindow() }).ProcessName"'
      ).toString().trim();
    }
  } catch {
    return '';
  }
  return '';
}

export function getSelectedText(): string {
  try {
    if (process.platform === 'darwin') {
      return execSync(
        'osascript -e \'tell application "System Events" to get value of attribute "AXSelectedText" of focused UI element of first application process whose frontmost is true\''
      ).toString().trim();
    }
    // Windows: no reliable accessibility API for selected text across apps
  } catch {
    return '';
  }
  return '';
}

export function simulateCopy(): void {
  if (process.platform === 'darwin') {
    execSync('osascript -e \'tell application "System Events" to keystroke "c" using command down\'');
  } else if (process.platform === 'win32') {
    execSync('powershell -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^c\')"');
  }
}

export function simulatePaste(): void {
  if (process.platform === 'darwin') {
    execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'');
  } else if (process.platform === 'win32') {
    execSync('powershell -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^v\')"');
  }
}

export function checkAccessibilityPermission(): boolean {
  if (process.platform !== 'darwin') return true;
  return systemPreferences.isTrustedAccessibilityClient(true);
}
|
|
@ -1,16 +1,22 @@
|
|||
import { BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron';
|
||||
import path from 'path';
|
||||
import { IPC_CHANNELS } from '../ipc/channels';
|
||||
import { checkAccessibilityPermission, getFrontmostApp, simulatePaste } from './platform';
|
||||
import { getServerPort } from './server';
|
||||
|
||||
const SHORTCUT = 'CommandOrControl+Option+S';
|
||||
let quickAskWindow: BrowserWindow | null = null;
|
||||
let pendingText = '';
|
||||
let pendingMode = '';
|
||||
let sourceApp = '';
|
||||
let savedClipboard = '';
|
||||
|
||||
function hideQuickAsk(): void {
|
||||
function destroyQuickAsk(): void {
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
|
||||
quickAskWindow.hide();
|
||||
quickAskWindow.close();
|
||||
}
|
||||
quickAskWindow = null;
|
||||
pendingMode = '';
|
||||
}
|
||||
|
||||
function clampToScreen(x: number, y: number, w: number, h: number): { x: number; y: number } {
|
||||
|
|
@ -23,16 +29,11 @@ function clampToScreen(x: number, y: number, w: number, h: number): { x: number;
|
|||
}
|
||||
|
||||
function createQuickAskWindow(x: number, y: number): BrowserWindow {
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
|
||||
quickAskWindow.setPosition(x, y);
|
||||
quickAskWindow.show();
|
||||
quickAskWindow.focus();
|
||||
return quickAskWindow;
|
||||
}
|
||||
destroyQuickAsk();
|
||||
|
||||
quickAskWindow = new BrowserWindow({
|
||||
width: 450,
|
||||
height: 550,
|
||||
height: 750,
|
||||
x,
|
||||
y,
|
||||
...(process.platform === 'darwin'
|
||||
|
|
@ -58,7 +59,7 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
|
|||
});
|
||||
|
||||
quickAskWindow.webContents.on('before-input-event', (_event, input) => {
|
||||
if (input.key === 'Escape') hideQuickAsk();
|
||||
if (input.key === 'Escape') destroyQuickAsk();
|
||||
});
|
||||
|
||||
quickAskWindow.webContents.setWindowOpenHandler(({ url }) => {
|
||||
|
|
@ -78,17 +79,20 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
|
|||
|
||||
export function registerQuickAsk(): void {
|
||||
const ok = globalShortcut.register(SHORTCUT, () => {
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed() && quickAskWindow.isVisible()) {
|
||||
hideQuickAsk();
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
|
||||
destroyQuickAsk();
|
||||
return;
|
||||
}
|
||||
|
||||
const text = clipboard.readText().trim();
|
||||
sourceApp = getFrontmostApp();
|
||||
savedClipboard = clipboard.readText();
|
||||
|
||||
const text = savedClipboard.trim();
|
||||
if (!text) return;
|
||||
|
||||
pendingText = text;
|
||||
const cursor = screen.getCursorScreenPoint();
|
||||
const pos = clampToScreen(cursor.x, cursor.y, 450, 550);
|
||||
const pos = clampToScreen(cursor.x, cursor.y, 450, 750);
|
||||
createQuickAskWindow(pos.x, pos.y);
|
||||
});
|
||||
|
||||
|
|
@ -101,6 +105,35 @@ export function registerQuickAsk(): void {
|
|||
pendingText = '';
|
||||
return text;
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_QUICK_ASK_MODE, (_event, mode: string) => {
|
||||
pendingMode = mode;
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_QUICK_ASK_MODE, (event) => {
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed() && event.sender.id === quickAskWindow.webContents.id) {
|
||||
return pendingMode;
|
||||
}
|
||||
return '';
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.REPLACE_TEXT, async (_event, text: string) => {
|
||||
if (!sourceApp) return;
|
||||
|
||||
if (!checkAccessibilityPermission()) return;
|
||||
|
||||
clipboard.writeText(text);
|
||||
destroyQuickAsk();
|
||||
|
||||
try {
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
simulatePaste();
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
clipboard.writeText(savedClipboard);
|
||||
} catch {
|
||||
clipboard.writeText(savedClipboard);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export function unregisterQuickAsk(): void {
|
||||
|
|
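Review note on the REPLACE_TEXT handler above: it writes the replacement to the clipboard, waits briefly, simulates a paste into the source app, waits again, and then restores the clipboard contents that were saved when the shortcut fired. The sketch below isolates that save, paste, restore sequence so it can be exercised without Electron. `ClipboardLike`, `sleep`, and `pasteAndRestore` are hypothetical names introduced for illustration; only the 50 ms and 100 ms delays and the order of operations come from the diff.

```typescript
// Hypothetical sketch, not the code from this commit.
// ClipboardLike mirrors the two Electron clipboard calls the handler uses.
interface ClipboardLike {
  readText(): string;
  writeText(text: string): void;
}

// Small promise-based delay helper (assumed, not part of the diff).
const sleep = (ms: number): Promise<void> =>
  new Promise<void>((resolve) => setTimeout(resolve, ms));

async function pasteAndRestore(
  clipboard: ClipboardLike,
  paste: () => void, // e.g. simulatePaste from platform.ts
  replacement: string,
  savedClipboard: string,
): Promise<void> {
  clipboard.writeText(replacement);
  try {
    await sleep(50); // give focus time to return to the source app
    paste();
    await sleep(100); // give the source app time to read the clipboard
  } catch {
    // Swallow paste failures, as the handler in the diff does.
  } finally {
    // Restore the user's clipboard whether or not the paste succeeded.
    clipboard.writeText(savedClipboard);
  }
}
```

Using `finally` keeps a single restore call where the handler in the diff repeats it in both the `try` and `catch` branches; the observable behaviour is intended to be the same.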
|
|||
|
|
@ -18,4 +18,7 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
};
|
||||
},
|
||||
getQuickAskText: () => ipcRenderer.invoke(IPC_CHANNELS.QUICK_ASK_TEXT),
|
||||
setQuickAskMode: (mode: string) => ipcRenderer.invoke(IPC_CHANNELS.SET_QUICK_ASK_MODE, mode),
|
||||
getQuickAskMode: () => ipcRenderer.invoke(IPC_CHANNELS.GET_QUICK_ASK_MODE),
|
||||
replaceText: (text: string) => ipcRenderer.invoke(IPC_CHANNELS.REPLACE_TEXT, text),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ export function LocalLoginForm() {
|
|||
animate={{ opacity: 1, y: 0, scale: 1 }}
|
||||
exit={{ opacity: 0, y: -10, scale: 0.95 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
className="rounded-lg border border-red-200 bg-red-50 p-4 text-red-900 shadow-sm dark:border-red-900/30 dark:bg-red-900/20 dark:text-red-200"
|
||||
className="rounded-lg border border-destructive/20 bg-destructive/10 p-4 text-destructive shadow-sm"
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<svg
|
||||
|
|
@ -109,7 +109,7 @@ export function LocalLoginForm() {
|
|||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
className="flex-shrink-0 mt-0.5 text-red-500 dark:text-red-400"
|
||||
className="flex-shrink-0 mt-0.5 text-destructive"
|
||||
>
|
||||
<title>Error Icon</title>
|
||||
<circle cx="12" cy="12" r="10" />
|
||||
|
|
@ -118,13 +118,13 @@ export function LocalLoginForm() {
|
|||
</svg>
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-sm font-semibold mb-1">{error.title}</p>
|
||||
<p className="text-sm text-red-700 dark:text-red-300">{error.message}</p>
|
||||
<p className="text-sm text-destructive">{error.message}</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => {
|
||||
setError({ title: null, message: null });
|
||||
}}
|
||||
className="flex-shrink-0 text-red-500 hover:text-red-700 dark:text-red-400 dark:hover:text-red-200 transition-colors"
|
||||
className="flex-shrink-0 text-destructive hover:text-destructive/90 transition-colors"
|
||||
aria-label="Dismiss error"
|
||||
type="button"
|
||||
>
|
||||
|
|
@ -150,10 +150,7 @@ export function LocalLoginForm() {
|
|||
</AnimatePresence>
|
||||
|
||||
<div>
|
||||
<label
|
||||
htmlFor="email"
|
||||
className="block text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<label htmlFor="email" className="block text-sm font-medium text-foreground">
|
||||
{t("email")}
|
||||
</label>
|
||||
<input
|
||||
|
|
@ -163,20 +160,17 @@ export function LocalLoginForm() {
|
|||
placeholder="you@example.com"
|
||||
value={username}
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
className={`mt-1 block w-full rounded-md border px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-2 focus:ring-offset-2 dark:bg-gray-800 dark:text-white transition-all ${
|
||||
className={`mt-1 block w-full rounded-md border px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-2 focus:ring-offset-2 bg-background text-foreground transition-all ${
|
||||
error.title
|
||||
? "border-red-300 focus:border-red-500 focus:ring-red-500 dark:border-red-700"
|
||||
: "border-gray-300 focus:border-blue-500 focus:ring-blue-500 dark:border-gray-700"
|
||||
? "border-destructive focus:border-destructive focus:ring-destructive"
|
||||
: "border-border focus:border-primary focus:ring-primary"
|
||||
}`}
|
||||
disabled={isLoggingIn}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label
|
||||
htmlFor="password"
|
||||
className="block text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<label htmlFor="password" className="block text-sm font-medium text-foreground">
|
||||
{t("password")}
|
||||
</label>
|
||||
<div className="relative">
|
||||
|
|
@ -187,17 +181,17 @@ export function LocalLoginForm() {
|
|||
placeholder="Enter your password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
className={`mt-1 block w-full rounded-md border pr-10 px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-2 focus:ring-offset-2 dark:bg-gray-800 dark:text-white transition-all ${
|
||||
className={`mt-1 block w-full rounded-md border pr-10 px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-2 focus:ring-offset-2 bg-background text-foreground transition-all ${
|
||||
error.title
|
||||
? "border-red-300 focus:border-red-500 focus:ring-red-500 dark:border-red-700"
|
||||
: "border-gray-300 focus:border-blue-500 focus:ring-blue-500 dark:border-gray-700"
|
||||
? "border-destructive focus:border-destructive focus:ring-destructive"
|
||||
: "border-border focus:border-primary focus:ring-primary"
|
||||
}`}
|
||||
disabled={isLoggingIn}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowPassword((prev) => !prev)}
|
||||
className="absolute inset-y-0 right-0 flex items-center pr-3 mt-1 text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200"
|
||||
className="absolute inset-y-0 right-0 flex items-center pr-3 mt-1 text-muted-foreground hover:text-foreground"
|
||||
aria-label={showPassword ? t("hide_password") : t("show_password")}
|
||||
>
|
||||
{showPassword ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||
|
|
@ -208,12 +202,12 @@ export function LocalLoginForm() {
|
|||
<button
|
||||
type="submit"
|
||||
disabled={isLoggingIn}
|
||||
className="relative w-full rounded-md bg-blue-600 px-4 py-1.5 md:py-2 text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 transition-all text-sm md:text-base flex items-center justify-center gap-2"
|
||||
className="relative w-full rounded-md bg-primary px-4 py-1.5 md:py-2 text-primary-foreground shadow-sm hover:bg-primary/90 focus:outline-none focus:ring-2 focus:ring-primary focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 transition-all text-sm md:text-base flex items-center justify-center gap-2"
|
||||
>
|
||||
<span className={isLoggingIn ? "invisible" : ""}>{t("sign_in")}</span>
|
||||
{isLoggingIn && (
|
||||
<span className="absolute inset-0 flex items-center justify-center">
|
||||
<Spinner size="sm" className="text-white" />
|
||||
<Spinner size="sm" className="text-primary-foreground" />
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
|
|
@ -221,12 +215,9 @@ export function LocalLoginForm() {
|
|||
|
||||
{authType === "LOCAL" && (
|
||||
<div className="mt-4 text-center text-sm">
|
||||
<p className="text-gray-600 dark:text-gray-400">
|
||||
<p className="text-muted-foreground">
|
||||
{t("dont_have_account")}{" "}
|
||||
<Link
|
||||
href="/register"
|
||||
className="font-medium text-blue-600 hover:text-blue-500 dark:text-blue-400"
|
||||
>
|
||||
<Link href="/register" className="font-medium text-primary hover:text-primary/90">
|
||||
{t("sign_up")}
|
||||
</Link>
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"use client";
|
||||
|
||||
import dynamic from "next/dynamic";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { HeroSection } from "@/components/homepage/hero-section";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
|
||||
const FeaturesCards = dynamic(
|
||||
() => import("@/components/homepage/features-card").then((m) => ({ default: m.FeaturesCards })),
|
||||
|
|
@ -26,6 +29,14 @@ const CTAHomepage = dynamic(
|
|||
);
|
||||
|
||||
export default function HomePage() {
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (getBearerToken()) {
|
||||
router.replace("/dashboard");
|
||||
}
|
||||
}, [router]);
|
||||
|
||||
return (
|
||||
<main className="min-h-screen bg-gradient-to-b from-gray-50 to-gray-100 text-gray-900 dark:from-black dark:to-gray-900 dark:text-white">
|
||||
<HeroSection />
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import { registerMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
|
|||
import { Logo } from "@/components/Logo";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { AUTH_TYPE } from "@/lib/env-config";
|
||||
import { AppError, ValidationError } from "@/lib/error";
|
||||
import {
|
||||
|
|
@ -38,6 +39,10 @@ export default function RegisterPage() {
|
|||
|
||||
// Check authentication type and redirect if not LOCAL
|
||||
useEffect(() => {
|
||||
if (getBearerToken()) {
|
||||
router.replace("/dashboard");
|
||||
return;
|
||||
}
|
||||
if (AUTH_TYPE !== "LOCAL") {
|
||||
router.push("/login");
|
||||
}
|
||||
|
|
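Review note on the register page hunk above: the new token check runs before the existing AUTH_TYPE check, so an already signed-in user is sent to the dashboard regardless of the configured auth type. A minimal sketch of that ordering as a pure function follows. `resolveRegisterRedirect` is a hypothetical helper, and the `string | null` token type is an assumption about what `getBearerToken()` returns.

```typescript
// Hypothetical helper, not part of this commit: returns the path the
// register page should redirect to, or null to stay on the page.
function resolveRegisterRedirect(
  token: string | null, // assumed shape of getBearerToken()'s result
  authType: string,
): string | null {
  // Signed-in users never see the register form.
  if (token) return "/dashboard";
  // Registration is only offered for LOCAL auth.
  if (authType !== "LOCAL") return "/login";
  return null;
}
```

In the page itself the first branch uses `router.replace` and the second uses `router.push`, so only the dashboard redirect avoids leaving a history entry.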
|
|||
|
|
@ -0,0 +1,19 @@
"use client";

import { motion } from "motion/react";
import { BuyPagesContent } from "@/components/settings/buy-pages-content";

export default function BuyPagesPage() {
	return (
		<div className="flex min-h-[calc(100vh-64px)] select-none items-center justify-center px-4 py-8">
			<motion.div
				initial={{ opacity: 0, y: 20 }}
				animate={{ opacity: 1, y: 0 }}
				transition={{ duration: 0.3 }}
				className="w-full max-w-md space-y-6"
			>
				<BuyPagesContent />
			</motion.div>
		</div>
	);
}
|
|
@ -183,6 +183,10 @@ export function DashboardClientLayout({
|
|||
);
|
||||
}
|
||||
|
||||
if (isOnboardingPage) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
return (
|
||||
<DocumentUploadDialogProvider>
|
||||
<OnboardingTour />
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ export function getDocumentTypeLabel(type: string): string {
|
|||
FILE: "File",
|
||||
SLACK_CONNECTOR: "Slack",
|
||||
TEAMS_CONNECTOR: "Microsoft Teams",
|
||||
ONEDRIVE_FILE: "OneDrive",
|
||||
DROPBOX_FILE: "Dropbox",
|
||||
NOTION_CONNECTOR: "Notion",
|
||||
YOUTUBE_VIDEO: "YouTube Video",
|
||||
GITHUB_CONNECTOR: "GitHub",
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ import {
|
|||
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle } from "@/components/ui/dialog";
|
||||
import {
|
||||
Drawer,
|
||||
DrawerContent,
|
||||
|
|
@ -234,6 +233,7 @@ export function DocumentsTableShell({
|
|||
mentionedDocIds,
|
||||
onToggleChatMention,
|
||||
isSearchMode = false,
|
||||
onOpenInTab,
|
||||
}: {
|
||||
documents: Document[];
|
||||
loading: boolean;
|
||||
|
|
@ -253,6 +253,8 @@ export function DocumentsTableShell({
|
|||
onToggleChatMention?: (doc: Document, mentioned: boolean) => void;
|
||||
/** Whether results are filtered by a search query or type filters */
|
||||
isSearchMode?: boolean;
|
||||
/** When provided, desktop "Preview" opens a document tab instead of the popup dialog */
|
||||
onOpenInTab?: (doc: Document) => void;
|
||||
}) {
|
||||
const t = useTranslations("documents");
|
||||
const { openDialog } = useDocumentUploadDialog();
|
||||
|
|
@ -742,9 +744,13 @@ export function DocumentsTableShell({
|
|||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-48">
|
||||
<DropdownMenuItem onClick={() => handleViewDocument(doc)}>
|
||||
<DropdownMenuItem
|
||||
onClick={() =>
|
||||
onOpenInTab ? onOpenInTab(doc) : handleViewDocument(doc)
|
||||
}
|
||||
>
|
||||
<Eye className="h-4 w-4" />
|
||||
Preview
|
||||
Open
|
||||
</DropdownMenuItem>
|
||||
{isEditable && (
|
||||
<DropdownMenuItem
|
||||
|
|
@ -923,26 +929,18 @@ export function DocumentsTableShell({
|
|||
</div>
|
||||
)}
|
||||
|
||||
{/* Document Content Viewer */}
|
||||
<Dialog open={!!viewingDoc} onOpenChange={(open) => !open && handleCloseViewer()}>
|
||||
<DialogContent className="max-w-4xl max-w-[92%] md:max-w-4xl max-h-[75vh] md:max-h-[80vh] flex flex-col overflow-hidden pb-0 p-3 md:p-6 gap-2 md:gap-4">
|
||||
<DialogHeader className="flex-shrink-0">
|
||||
<DialogTitle className="text-sm md:text-lg leading-tight pr-6">
|
||||
{/* Document Content Viewer (mobile drawer) */}
|
||||
<Drawer open={!!viewingDoc} onOpenChange={(open) => !open && handleCloseViewer()}>
|
||||
<DrawerContent className="max-h-[85vh] flex flex-col">
|
||||
<DrawerHandle />
|
||||
<DrawerHeader className="text-left shrink-0">
|
||||
<DrawerTitle className="text-base leading-tight break-words">
|
||||
{viewingDoc?.title}
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
</DrawerTitle>
|
||||
</DrawerHeader>
|
||||
<div
|
||||
onScroll={handlePreviewScroll}
|
||||
className={[
|
||||
"overflow-y-auto flex-1 min-h-0 px-1 md:px-6 select-text",
|
||||
"max-md:text-xs",
|
||||
"max-md:[&_h1]:text-base! max-md:[&_h1]:mt-3!",
|
||||
"max-md:[&_h2]:text-sm! max-md:[&_h2]:mt-2!",
|
||||
"max-md:[&_h3]:text-xs! max-md:[&_h3]:mt-2!",
|
||||
"max-md:[&_h4]:text-xs!",
|
||||
"max-md:[&_td]:text-[11px]! max-md:[&_td]:px-2! max-md:[&_td]:py-1.5!",
|
||||
"max-md:[&_th]:text-[11px]! max-md:[&_th]:px-2! max-md:[&_th]:py-1.5!",
|
||||
].join(" ")}
|
||||
className="overflow-y-auto flex-1 min-h-0 px-4 pb-6 select-text text-xs [&_h1]:text-base! [&_h1]:mt-3! [&_h2]:text-sm! [&_h2]:mt-2! [&_h3]:text-xs! [&_h3]:mt-2! [&_h4]:text-xs! [&_td]:text-[11px]! [&_td]:px-2! [&_td]:py-1.5! [&_th]:text-[11px]! [&_th]:px-2! [&_th]:py-1.5!"
|
||||
style={{
|
||||
maskImage: `linear-gradient(to bottom, ${previewScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${previewScrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
WebkitMaskImage: `linear-gradient(to bottom, ${previewScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${previewScrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
|
|
@ -956,8 +954,8 @@ export function DocumentsTableShell({
|
|||
<MarkdownViewer content={viewingContent} />
|
||||
)}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</DrawerContent>
|
||||
</Drawer>
|
||||
|
||||
{/* Document Metadata Viewer (Ctrl+Click) */}
|
||||
<JsonMetadataViewer
|
||||
|
|
@ -992,9 +990,10 @@ export function DocumentsTableShell({
|
|||
handleDeleteFromMenu();
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? <Spinner size="sm" /> : "Delete"}
|
||||
<span className={isDeleting ? "opacity-0" : ""}>Delete</span>
|
||||
{isDeleting && <Spinner size="sm" className="absolute" />}
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
|
|
@ -1027,7 +1026,7 @@ export function DocumentsTableShell({
|
|||
}}
|
||||
>
|
||||
<Eye className="h-4 w-4" />
|
||||
Preview
|
||||
Open
|
||||
</Button>
|
||||
{mobileActionDoc &&
|
||||
EDITABLE_DOCUMENT_TYPES.includes(
|
||||
|
|
@ -1110,9 +1109,10 @@ export function DocumentsTableShell({
|
|||
handleBulkDelete();
|
||||
}}
|
||||
disabled={isBulkDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isBulkDeleting ? <Spinner size="sm" /> : "Delete"}
|
||||
<span className={isBulkDeleting ? "opacity-0" : ""}>Delete</span>
|
||||
{isBulkDeleting && <Spinner size="sm" className="absolute" />}
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import type { Document } from "./types";
|
||||
|
||||
const EDITABLE_DOCUMENT_TYPES = ["NOTE"] as const;
|
||||
const EDITABLE_DOCUMENT_TYPES = ["FILE", "NOTE"] as const;
|
||||
|
||||
// SURFSENSE_DOCS are system-managed and cannot be deleted
|
||||
const NON_DELETABLE_DOCUMENT_TYPES = ["SURFSENSE_DOCS"] as const;
|
||||
|
|
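Review note on the hunk above: because `EDITABLE_DOCUMENT_TYPES` is declared `as const`, its element type is the union `"FILE" | "NOTE"`, and calling `.includes()` with a plain `string` needs a cast or a type guard. The sketch below shows one way to write such a guard. `EditableDocumentType` and `isEditableDocumentType` are hypothetical names; the call sites in this diff are truncated, so how the repository actually performs the check is not shown.

```typescript
// The constant as it reads after this commit.
const EDITABLE_DOCUMENT_TYPES = ["FILE", "NOTE"] as const;

// Hypothetical union type derived from the tuple: "FILE" | "NOTE".
type EditableDocumentType = (typeof EDITABLE_DOCUMENT_TYPES)[number];

// Hypothetical type guard: widens the tuple to readonly string[] so that
// includes() accepts an arbitrary string, then narrows the result.
function isEditableDocumentType(type: string): type is EditableDocumentType {
  return (EDITABLE_DOCUMENT_TYPES as readonly string[]).includes(type);
}
```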
|
|||
|
|
@ -33,14 +33,13 @@ import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
|
|||
import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
|
||||
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||
import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
||||
import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
|
||||
import { Thread } from "@/components/assistant-ui/thread";
|
||||
import { MobileEditorPanel } from "@/components/editor-panel/editor-panel";
|
||||
import { MobileHitlEditPanel } from "@/components/hitl-edit-panel/hitl-edit-panel";
|
||||
import { MobileReportPanel } from "@/components/report-panel/report-panel";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
||||
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
|
|
@ -57,6 +56,7 @@ import {
|
|||
buildContentForPersistence,
|
||||
buildContentForUI,
|
||||
type ContentPartsState,
|
||||
FrameBatchedUpdater,
|
||||
readSSEStream,
|
||||
type ThinkingStepData,
|
||||
updateThinkingSteps,
|
||||
|
|
@ -70,6 +70,7 @@ import {
|
|||
getThreadMessages,
|
||||
type ThreadRecord,
|
||||
} from "@/lib/chat/thread-persistence";
|
||||
import { NotFoundError } from "@/lib/error";
|
||||
import {
|
||||
trackChatCreated,
|
||||
trackChatError,
|
||||
|
|
@ -131,6 +132,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
|
|||
* Tools that should render custom UI in the chat.
|
||||
*/
|
||||
const TOOLS_WITH_UI = new Set([
|
||||
"web_search",
|
||||
"generate_podcast",
|
||||
"generate_report",
|
||||
"generate_video_presentation",
|
||||
|
|
@ -144,6 +146,10 @@ const TOOLS_WITH_UI = new Set([
|
|||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
"create_onedrive_file",
|
||||
"delete_onedrive_file",
|
||||
"create_dropbox_file",
|
||||
"delete_dropbox_file",
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
|
|
@ -192,6 +198,7 @@ export default function NewChatPage() {
|
|||
const closeReportPanel = useSetAtom(closeReportPanelAtom);
|
||||
const closeEditorPanel = useSetAtom(closeEditorPanelAtom);
|
||||
const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom);
|
||||
const removeChatTab = useSetAtom(removeChatTabAtom);
|
||||
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
||||
|
||||
// Get current user for author info in shared chats
|
||||
|
|
@ -271,7 +278,6 @@ export default function NewChatPage() {
|
|||
|
||||
// Initialize thread and load messages
|
||||
// For new chats (no urlChatId), we use lazy creation - thread is created on first message
|
||||
// biome-ignore lint/correctness/useExhaustiveDependencies: searchSpaceId triggers re-init when switching spaces with the same urlChatId
|
||||
const initializeThread = useCallback(async () => {
|
||||
setIsInitializing(true);
|
||||
|
||||
|
|
@ -322,6 +328,14 @@ export default function NewChatPage() {
|
|||
// This improves UX (instant load) and avoids orphan threads
|
||||
} catch (error) {
|
||||
console.error("[NewChatPage] Failed to initialize thread:", error);
|
||||
if (urlChatId > 0 && error instanceof NotFoundError) {
|
||||
removeChatTab(urlChatId);
|
||||
if (typeof window !== "undefined") {
|
||||
window.history.replaceState(null, "", `/dashboard/${searchSpaceId}/new-chat`);
|
||||
}
|
||||
toast.error("This chat was deleted.");
|
||||
return;
|
||||
}
|
||||
// Keep threadId as null - don't use Date.now() as it creates an invalid ID
|
||||
// that will cause 404 errors on subsequent API calls
|
||||
setThreadId(null);
|
||||
|
|
@ -332,15 +346,16 @@ export default function NewChatPage() {
|
|||
}
|
||||
}, [
|
||||
urlChatId,
|
||||
searchSpaceId,
|
||||
setMessageDocumentsMap,
|
||||
setMentionedDocuments,
|
||||
setSidebarDocuments,
|
||||
closeReportPanel,
|
||||
closeEditorPanel,
|
||||
removeChatTab,
|
||||
searchSpaceId,
|
||||
]);
|
||||
|
||||
// Initialize on mount
|
||||
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
|
||||
useEffect(() => {
|
||||
initializeThread();
|
||||
}, [initializeThread]);
|
||||
|
|
@ -483,18 +498,17 @@ export default function NewChatPage() {
|
|||
// Add user message to state
|
||||
const userMsgId = `msg-user-${Date.now()}`;
|
||||
|
||||
// Include author metadata for shared chats
|
||||
const authorMetadata =
|
||||
currentThread?.visibility === "SEARCH_SPACE" && currentUser
|
||||
? {
|
||||
custom: {
|
||||
author: {
|
||||
displayName: currentUser.display_name ?? null,
|
||||
avatarUrl: currentUser.avatar_url ?? null,
|
||||
},
|
||||
// Always include author metadata so the UI layer can decide visibility
|
||||
const authorMetadata = currentUser
|
||||
? {
|
||||
custom: {
|
||||
author: {
|
||||
displayName: currentUser.display_name ?? null,
|
||||
avatarUrl: currentUser.avatar_url ?? null,
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
const userMessage: ThreadMessageLike = {
|
||||
id: userMsgId,
|
||||
|
|
@ -570,6 +584,7 @@ export default function NewChatPage() {
|
|||
// Prepare assistant message
|
||||
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||
const batcher = new FrameBatchedUpdater();
|
||||
|
||||
const contentPartsState: ContentPartsState = {
|
||||
contentParts: [],
|
||||
|
|
@ -641,33 +656,30 @@ export default function NewChatPage() {
|
|||
throw new Error(`Backend error: ${response.status}`);
|
||||
}
|
||||
|
||||
const flushMessages = () => {
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
};
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
appendText(contentPartsState, parsed.delta);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
// Add tool call inline - this breaks the current text segment
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "tool-input-available": {
|
||||
// Update existing tool call's args, or add if not exists
|
||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
|
||||
} else {
|
||||
|
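Review note on the streaming hunks: the per-event `setMessages` calls are replaced by `batcher.schedule(flushMessages)` for text deltas and `batcher.flush()` for tool events, which suggests that `FrameBatchedUpdater` coalesces rapid updates into at most one render per animation frame. Its implementation is not part of this diff, so the class below is only a sketch of that technique under stated assumptions. `FrameBatchedUpdaterSketch` and its injected scheduler functions are hypothetical, and the real class may behave differently, in particular when `flush()` is called with nothing scheduled.

```typescript
// Hypothetical sketch of frame-batched updates; NOT the repository's
// FrameBatchedUpdater. The frame scheduler is injected so the class can
// run outside a browser.
type RequestFrame = (cb: () => void) => number;
type CancelFrame = (handle: number) => void;

class FrameBatchedUpdaterSketch {
  private requestFrame: RequestFrame;
  private cancelFrame: CancelFrame;
  private pending: (() => void) | null = null;
  private handle: number | null = null;

  constructor(requestFrame: RequestFrame, cancelFrame: CancelFrame) {
    this.requestFrame = requestFrame;
    this.cancelFrame = cancelFrame;
  }

  // Queue an update; repeated calls within one frame keep only the latest.
  schedule(update: () => void): void {
    this.pending = update;
    if (this.handle !== null) return; // a frame is already requested
    this.handle = this.requestFrame(() => {
      this.handle = null;
      this.runPending();
    });
  }

  // Run a pending update immediately instead of waiting for the frame.
  flush(): void {
    this.cancelRequestedFrame();
    this.runPending();
  }

  // Drop any pending update without running it (e.g. on error or abort).
  dispose(): void {
    this.cancelRequestedFrame();
    this.pending = null;
  }

  private cancelRequestedFrame(): void {
    if (this.handle !== null) {
      this.cancelFrame(this.handle);
      this.handle = null;
    }
  }

  private runPending(): void {
    const update = this.pending;
    this.pending = null;
    if (update) update();
  }
}
```

In a browser, the sketch could be constructed with `requestAnimationFrame` and `cancelAnimationFrame` as the two arguments. Flushing synchronously on tool events, as the diff does, keeps tool-call UI from lagging a frame behind the text that precedes it.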
|
@ -679,23 +691,14 @@ export default function NewChatPage() {
|
|||
parsed.input || {}
|
||||
);
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-output-available": {
|
||||
// Update the tool call with its result
|
||||
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
|
||||
markInterruptsCompleted(contentParts);
|
||||
// Handle podcast-specific logic
|
||||
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
|
||||
// Check if this is a podcast tool by looking at the content part
|
||||
const idx = toolCallIndices.get(parsed.toolCallId);
|
||||
if (idx !== undefined) {
|
||||
const part = contentParts[idx];
|
||||
|
|
@ -704,13 +707,7 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -718,14 +715,10 @@ export default function NewChatPage() {
|
|||
const stepData = parsed.data as ThinkingStepData;
|
||||
if (stepData?.id) {
|
||||
currentThinkingSteps.set(stepData.id, stepData);
|
||||
updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
if (didUpdate) {
|
||||
scheduleFlush();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -802,6 +795,8 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
|
||||
batcher.flush();
|
||||
|
||||
// Skip persistence for interrupted messages -- handleResume will persist the final version
|
||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
||||
if (contentParts.length > 0 && !wasInterrupted) {
|
||||
|
|
@ -831,6 +826,7 @@ export default function NewChatPage() {
|
|||
trackChatResponseReceived(searchSpaceId, currentThreadId);
|
||||
}
|
||||
} catch (error) {
|
||||
batcher.dispose();
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
// Request was cancelled by user - persist partial response if any content was received
|
||||
const hasContent = contentParts.some(
|
||||
|
|
@ -898,10 +894,11 @@ export default function NewChatPage() {
|
|||
setMentionedDocuments,
|
||||
setSidebarDocuments,
|
||||
setMessageDocumentsMap,
|
||||
setAgentCreatedDocuments,
|
||||
queryClient,
|
||||
currentThread,
|
||||
currentUser,
|
||||
disabledTools,
|
||||
updateChatTabTitle,
|
||||
]
|
||||
);
|
||||
|
||||
|
|
@ -929,6 +926,7 @@ export default function NewChatPage() {
|
|||
abortControllerRef.current = controller;
|
||||
|
||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||
const batcher = new FrameBatchedUpdater();
|
||||
|
||||
const contentPartsState: ContentPartsState = {
|
||||
contentParts: [],
|
||||
|
|
@ -1016,28 +1014,27 @@ export default function NewChatPage() {
|
|||
throw new Error(`Backend error: ${response.status}`);
|
||||
}
|
||||
|
||||
const flushMessages = () => {
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
};
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
appendText(contentPartsState, parsed.delta);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "tool-input-available":
|
||||
|
|
@ -1054,13 +1051,7 @@ export default function NewChatPage() {
|
|||
parsed.input || {}
|
||||
);
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "tool-output-available":
|
||||
|
|
@ -1068,27 +1059,17 @@ export default function NewChatPage() {
|
|||
result: parsed.output,
|
||||
});
|
||||
markInterruptsCompleted(contentParts);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "data-thinking-step": {
|
||||
const stepData = parsed.data as ThinkingStepData;
|
||||
if (stepData?.id) {
|
||||
currentThinkingSteps.set(stepData.id, stepData);
|
||||
updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
if (didUpdate) {
|
||||
scheduleFlush();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -1142,6 +1123,8 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
|
||||
batcher.flush();
|
||||
|
||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
||||
if (contentParts.length > 0) {
|
||||
try {
|
||||
|
|
@ -1158,6 +1141,7 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
} catch (error) {
|
||||
batcher.dispose();
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1303,6 +1287,7 @@ export default function NewChatPage() {
|
|||
toolCallIndices: new Map(),
|
||||
};
|
||||
const { contentParts, toolCallIndices } = contentPartsState;
|
||||
const batcher = new FrameBatchedUpdater();
|
||||
|
||||
// Add placeholder messages to UI
|
||||
// Always add back the user message (with new query for edit, or original content for reload)
|
||||
|
|
@ -1347,28 +1332,27 @@ export default function NewChatPage() {
|
|||
throw new Error(`Backend error: ${response.status}`);
|
||||
}
|
||||
|
||||
const flushMessages = () => {
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
};
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
appendText(contentPartsState, parsed.delta);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
scheduleFlush();
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "tool-input-available":
|
||||
|
|
@ -1383,13 +1367,7 @@ export default function NewChatPage() {
|
|||
parsed.input || {}
|
||||
);
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "tool-output-available":
|
||||
|
|
@ -1404,27 +1382,17 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
batcher.flush();
|
||||
break;
|
||||
|
||||
case "data-thinking-step": {
|
||||
const stepData = parsed.data as ThinkingStepData;
|
||||
if (stepData?.id) {
|
||||
currentThinkingSteps.set(stepData.id, stepData);
|
||||
updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContentForUI(contentPartsState, TOOLS_WITH_UI) }
|
||||
: m
|
||||
)
|
||||
);
|
||||
const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
|
||||
if (didUpdate) {
|
||||
scheduleFlush();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -1434,6 +1402,8 @@ export default function NewChatPage() {
|
|||
}
|
||||
}
|
||||
|
||||
batcher.flush();
|
||||
|
||||
// Persist messages after streaming completes
|
||||
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
|
||||
if (contentParts.length > 0) {
|
||||
|
|
@ -1475,6 +1445,7 @@ export default function NewChatPage() {
|
|||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
batcher.dispose();
|
||||
console.error("[NewChatPage] Regeneration error:", error);
|
||||
trackChatError(
|
||||
searchSpaceId,
|
||||
|
|
@ -1482,7 +1453,6 @@ export default function NewChatPage() {
|
|||
error instanceof Error ? error.message : "Unknown error"
|
||||
);
|
||||
toast.error("Failed to regenerate response. Please try again.");
|
||||
// Update assistant message with error
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
|
|
|
|||
|
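Note on the hunks above: the repeated inline `setMessages(...)` calls in the streaming handlers are replaced by a shared `flushMessages` closure driven by a `FrameBatchedUpdater` (`schedule` for text deltas and thinking steps, `flush` after tool events and at stream end, `dispose` on error). The class itself is not part of this diff, so the following is only a minimal sketch of how such a batcher could work, under stated assumptions: the name `FrameBatchedUpdaterSketch`, the injectable `FrameScheduler` interface, and all internals are hypothetical; only the method names `schedule`, `flush`, and `dispose` come from the call sites. In particular, this sketch's `flush()` runs only an update that was previously queued; whether the real `flush()` also renders when nothing was scheduled is not visible here.

```typescript
// Hypothetical sketch: the real FrameBatchedUpdater is not shown in this diff.
// Only the method names schedule/flush/dispose are taken from the call sites.

// Assumed abstraction over requestAnimationFrame/cancelAnimationFrame so the
// batcher can be driven manually outside a browser.
interface FrameScheduler {
	request(callback: () => void): number;
	cancel(handle: number): void;
}

class FrameBatchedUpdaterSketch {
	private readonly scheduler: FrameScheduler;
	private handle: number | null = null;
	private pending: (() => void) | null = null;
	private disposed = false;

	constructor(scheduler: FrameScheduler) {
		this.scheduler = scheduler;
	}

	/** Queue an update; repeated calls before the next frame collapse into one. */
	schedule(update: () => void): void {
		if (this.disposed) return;
		this.pending = update; // keep only the most recent callback
		if (this.handle !== null) return; // a frame is already requested
		this.handle = this.scheduler.request(() => {
			this.handle = null;
			this.runPending();
		});
	}

	/** Run a queued update immediately and cancel the requested frame. */
	flush(): void {
		this.cancelFrame();
		this.runPending();
	}

	/** Drop a queued update without running it; later schedule() calls are ignored. */
	dispose(): void {
		this.cancelFrame();
		this.pending = null;
		this.disposed = true;
	}

	private cancelFrame(): void {
		if (this.handle !== null) {
			this.scheduler.cancel(this.handle);
			this.handle = null;
		}
	}

	private runPending(): void {
		const update = this.pending;
		this.pending = null;
		if (update !== null) update();
	}
}
```

In a browser, the scheduler could be built from `requestAnimationFrame` and `cancelAnimationFrame`. The point of the design is that a burst of `text-delta` events between two frames triggers one React state update instead of one per event, while tool events can still force an immediate render through `flush()`.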
|
@ -1,7 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue, useSetAtom } from "jotai";
|
||||
import { motion } from "motion/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
|
@ -13,19 +12,17 @@ import {
|
|||
globalNewLLMConfigsAtom,
|
||||
llmPreferencesAtom,
|
||||
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
import { searchSpaceSettingsDialogAtom } from "@/atoms/settings/settings-dialog.atoms";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
|
||||
export default function OnboardPage() {
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = Number(params.search_space_id);
|
||||
const setSearchSpaceSettingsDialog = useSetAtom(searchSpaceSettingsDialogAtom);
|
||||
|
||||
// Queries
|
||||
const {
|
||||
data: globalConfigs = [],
|
||||
|
|
@ -62,14 +59,12 @@ export default function OnboardPage() {
|
|||
preferences.document_summary_llm_id !== null &&
|
||||
preferences.document_summary_llm_id !== undefined;
|
||||
|
||||
// If onboarding is already complete, redirect immediately
|
||||
useEffect(() => {
|
||||
if (!preferencesLoading && isOnboardingComplete) {
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
}
|
||||
}, [preferencesLoading, isOnboardingComplete, router, searchSpaceId]);
|
||||
|
||||
// Auto-configure if global configs are available
|
||||
useEffect(() => {
|
||||
const autoConfigureWithGlobal = async () => {
|
||||
if (hasAttemptedAutoConfig.current) return;
|
||||
|
|
@ -77,7 +72,6 @@ export default function OnboardPage() {
|
|||
if (!globalConfigsLoaded) return;
|
||||
if (isOnboardingComplete) return;
|
||||
|
||||
// Only auto-configure if we have global configs
|
||||
if (globalConfigs.length > 0) {
|
||||
hasAttemptedAutoConfig.current = true;
|
||||
setIsAutoConfiguring(true);
|
||||
|
|
@ -97,7 +91,6 @@ export default function OnboardPage() {
|
|||
description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`,
|
||||
});
|
||||
|
||||
// Redirect to new-chat
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
} catch (error) {
|
||||
console.error("Auto-configuration failed:", error);
|
||||
|
|
@ -119,13 +112,10 @@ export default function OnboardPage() {
|
|||
router,
|
||||
]);
|
||||
|
||||
// Handle form submission
|
||||
const handleSubmit = async (formData: LLMConfigFormData) => {
|
||||
try {
|
||||
// Create the config
|
||||
const newConfig = await createConfig(formData);
|
||||
|
||||
// Auto-assign to all roles
|
||||
await updatePreferences({
|
||||
search_space_id: searchSpaceId,
|
||||
data: {
|
||||
|
|
@ -138,7 +128,6 @@ export default function OnboardPage() {
|
|||
description: "Redirecting to chat...",
|
||||
});
|
||||
|
||||
// Redirect to new-chat
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
} catch (error) {
|
||||
console.error("Failed to create config:", error);
|
||||
|
|
@ -150,124 +139,59 @@ export default function OnboardPage() {
|
|||
|
||||
const isSubmitting = isCreating || isUpdatingPreferences;
|
||||
|
||||
// Loading state
|
||||
if (globalConfigsLoading || preferencesLoading || isAutoConfiguring) {
|
||||
return (
|
||||
<div className="min-h-screen bg-gradient-to-b from-background to-muted/20 flex items-center justify-center">
|
||||
<motion.div
|
||||
initial={{ opacity: 0, scale: 0.95 }}
|
||||
animate={{ opacity: 1, scale: 1 }}
|
||||
className="text-center space-y-6"
|
||||
>
|
||||
<div className="relative">
|
||||
<div className="absolute inset-0 blur-3xl bg-gradient-to-r from-violet-500/20 to-cyan-500/20 rounded-full" />
|
||||
<div className="relative flex items-center justify-center w-24 h-24 mx-auto rounded-2xl bg-gradient-to-br from-violet-500 to-purple-600 shadow-2xl shadow-violet-500/25">
|
||||
<Spinner size="xl" className="text-white" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<h2 className="text-2xl font-bold tracking-tight">
|
||||
{isAutoConfiguring ? "Setting up your AI..." : "Loading..."}
|
||||
</h2>
|
||||
<p className="text-muted-foreground">
|
||||
{isAutoConfiguring
|
||||
? "Auto-configuring with available settings"
|
||||
: "Please wait while we check your configuration"}
|
||||
const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring;
|
||||
useGlobalLoadingEffect(isLoading);
|
||||
|
||||
if (isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (globalConfigs.length > 0 && !isAutoConfiguring) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="h-screen flex flex-col items-center p-4 bg-background dark:bg-neutral-900 select-none overflow-hidden">
|
||||
<div className="w-full max-w-lg flex flex-col min-h-0 h-full gap-6 py-8">
|
||||
{/* Header */}
|
||||
<div className="text-center space-y-3 shrink-0">
|
||||
<Logo className="w-12 h-12 mx-auto" />
|
||||
<div className="space-y-1">
|
||||
<h1 className="text-2xl font-semibold tracking-tight">Configure Your AI</h1>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Add your LLM provider to get started with SurfSense
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex justify-center gap-1">
|
||||
{[0, 1, 2].map((i) => (
|
||||
<motion.div
|
||||
key={i}
|
||||
className="w-2 h-2 rounded-full bg-violet-500"
|
||||
animate={{ scale: [1, 1.5, 1], opacity: [0.5, 1, 0.5] }}
|
||||
transition={{ duration: 1, repeat: Infinity, delay: i * 0.2 }}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
</div>
|
||||
|
||||
// If global configs exist but auto-config failed, show simple message
|
||||
if (globalConfigs.length > 0 && !isAutoConfiguring) {
|
||||
return null; // Will redirect via useEffect
|
||||
}
|
||||
{/* Form card */}
|
||||
<div className="rounded-xl border bg-background dark:bg-neutral-900 flex-1 min-h-0 overflow-y-auto px-6 py-6">
|
||||
<LLMConfigForm
|
||||
searchSpaceId={searchSpaceId}
|
||||
onSubmit={handleSubmit}
|
||||
mode="create"
|
||||
showAdvanced={true}
|
||||
formId="onboard-config-form"
|
||||
initialData={{
|
||||
citations_enabled: true,
|
||||
use_default_system_instructions: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
// No global configs - show the config form
|
||||
return (
|
||||
<div className="min-h-screen bg-gradient-to-b from-background via-background to-muted/30">
|
||||
<div className="container mx-auto px-4 py-8 md:py-12 max-w-3xl">
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.5 }}
|
||||
className="space-y-8"
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="text-center space-y-4">
|
||||
<motion.div
|
||||
initial={{ scale: 0 }}
|
||||
animate={{ scale: 1 }}
|
||||
transition={{ type: "spring", delay: 0.2 }}
|
||||
className="relative inline-block"
|
||||
>
|
||||
<Logo className="w-20 h-20 mx-auto rounded-full" />
|
||||
</motion.div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<h1 className="text-3xl font-bold tracking-tight">Configure Your AI</h1>
|
||||
<p className="text-muted-foreground text-lg">
|
||||
Add your LLM provider to get started with SurfSense
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Config Form */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ delay: 0.3 }}
|
||||
{/* Footer */}
|
||||
<div className="text-center space-y-4 shrink-0">
|
||||
<Button
|
||||
type="submit"
|
||||
form="onboard-config-form"
|
||||
disabled={isSubmitting}
|
||||
className="relative text-sm h-9 min-w-[180px]"
|
||||
>
|
||||
<Card className="border-2 border-muted shadow-xl overflow-hidden">
|
||||
<CardHeader className="pb-4">
|
||||
<CardTitle className="text-xl">LLM Configuration</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<LLMConfigForm
|
||||
searchSpaceId={searchSpaceId}
|
||||
onSubmit={handleSubmit}
|
||||
isSubmitting={isSubmitting}
|
||||
mode="create"
|
||||
showAdvanced={true}
|
||||
submitLabel="Start Using SurfSense"
|
||||
initialData={{
|
||||
citations_enabled: true,
|
||||
use_default_system_instructions: true,
|
||||
}}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
|
||||
{/* Footer note */}
|
||||
<motion.p
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ delay: 0.5 }}
|
||||
className="text-center text-sm text-muted-foreground"
|
||||
>
|
||||
You can add more configurations and customize settings anytime in{" "}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setSearchSpaceSettingsDialog({ open: true, initialTab: "general" })}
|
||||
className="text-violet-500 hover:underline"
|
||||
>
|
||||
Settings
|
||||
</button>
|
||||
</motion.p>
|
||||
</motion.div>
|
||||
<span className={isSubmitting ? "opacity-0" : ""}>Start Using SurfSense</span>
|
||||
{isSubmitting && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
<p className="text-xs text-muted-foreground">You can add more configurations later</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
"use client";
|
||||
|
||||
import { CircleSlash2 } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useParams } from "next/navigation";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
|
||||
export default function PurchaseCancelPage() {
|
||||
const params = useParams();
|
||||
const searchSpaceId = String(params.search_space_id ?? "");
|
||||
|
||||
return (
|
||||
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
||||
<Card className="w-full max-w-lg">
|
||||
<CardHeader className="text-center">
|
||||
<CircleSlash2 className="mx-auto h-10 w-10 text-muted-foreground" />
|
||||
<CardTitle className="text-2xl">Checkout canceled</CardTitle>
|
||||
<CardDescription>
|
||||
No charge was made and your current pages are unchanged.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="text-center text-sm text-muted-foreground">
|
||||
You can return to the pricing options and try again whenever you're ready.
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-col gap-2 sm:flex-row">
|
||||
<Button asChild className="w-full">
|
||||
<Link href={`/dashboard/${searchSpaceId}/more-pages`}>Back to Buy Pages</Link>
|
||||
</Button>
|
||||
<Button asChild variant="outline" className="w-full">
|
||||
<Link href={`/dashboard/${searchSpaceId}/new-chat`}>Back to Dashboard</Link>
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"use client";
|
||||
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { CheckCircle2 } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
|
||||
export default function PurchaseSuccessPage() {
|
||||
const params = useParams();
|
||||
const queryClient = useQueryClient();
|
||||
const searchSpaceId = String(params.search_space_id ?? "");
|
||||
|
||||
useEffect(() => {
|
||||
void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
|
||||
}, [queryClient]);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
||||
<Card className="w-full max-w-lg">
|
||||
<CardHeader className="text-center">
|
||||
<CheckCircle2 className="mx-auto h-10 w-10 text-emerald-500" />
|
||||
<CardTitle className="text-2xl">Purchase complete</CardTitle>
|
||||
<CardDescription>
|
||||
Your additional pages are being applied to your account now.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-3 text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Your sidebar usage meter should refresh automatically in a moment.
|
||||
</p>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-col gap-2">
|
||||
<Button asChild className="w-full">
|
||||
<Link href={`/dashboard/${searchSpaceId}/new-chat`}>Back to Dashboard</Link>
|
||||
</Button>
|
||||
<Button asChild variant="outline" className="w-full">
|
||||
<Link href={`/dashboard/${searchSpaceId}/more-pages`}>Buy More Pages</Link>
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -308,7 +308,8 @@ export function TeamContent({ searchSpaceId }: TeamContentProps) {
|
|||
{invitesLoading ? (
|
||||
<Skeleton className="h-9 w-32 rounded-md" />
|
||||
) : (
|
||||
canInvite && activeInvites.length > 0 && (
|
||||
canInvite &&
|
||||
activeInvites.length > 0 && (
|
||||
<AllInvitesDialog invites={activeInvites} onRevokeInvite={handleRevokeInvite} />
|
||||
)
|
||||
)}
|
||||
|
|
@ -763,7 +764,7 @@ function CreateInviteDialog({
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter className="gap-3 sm:gap-2">
|
||||
<DialogFooter>
|
||||
<Button variant="secondary" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,128 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertTriangle, Copy, Globe, Sparkles } from "lucide-react";
|
||||
import { useCallback, useState } from "react";
|
||||
import { copyPromptMutationAtom } from "@/atoms/prompts/prompts-mutation.atoms";
|
||||
import { publicPromptsAtom } from "@/atoms/prompts/prompts-query.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
export function CommunityPromptsContent() {
|
||||
const { data: prompts, isLoading, isError } = useAtomValue(publicPromptsAtom);
|
||||
const { mutateAsync: copyPrompt } = useAtomValue(copyPromptMutationAtom);
|
||||
const [copyingIds, setCopyingIds] = useState<Set<number>>(new Set());
|
||||
const [expandedId, setExpandedId] = useState<number | null>(null);
|
||||
|
||||
const handleCopy = useCallback(
|
||||
async (id: number) => {
|
||||
setCopyingIds((prev) => new Set(prev).add(id));
|
||||
try {
|
||||
await copyPrompt(id);
|
||||
} catch {
|
||||
// toast handled by mutation atom
|
||||
} finally {
|
||||
setCopyingIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.delete(id);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
},
|
||||
[copyPrompt]
|
||||
);
|
||||
|
||||
const list = prompts ?? [];
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner className="size-6" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-destructive/40 p-8 text-center">
|
||||
<AlertTriangle className="mx-auto size-8 text-destructive/60" />
|
||||
<p className="mt-2 text-sm text-destructive">Failed to load community prompts</p>
|
||||
<p className="text-xs text-muted-foreground">Please try refreshing the page.</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6 min-w-0 overflow-hidden">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Prompts shared by other users. Add any to your collection with one click.
|
||||
</p>
|
||||
|
||||
{list.length === 0 && (
|
||||
<div className="rounded-lg border border-dashed border-border/60 p-8 text-center">
|
||||
<Globe className="mx-auto size-8 text-muted-foreground/40" />
|
||||
<p className="mt-2 text-sm text-muted-foreground">No community prompts yet</p>
|
||||
<p className="text-xs text-muted-foreground/60">
|
||||
Share your own prompts from the My Prompts tab
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{list.length > 0 && (
|
||||
<div className="space-y-2">
|
||||
{list.map((prompt) => (
|
||||
<div
|
||||
key={prompt.id}
|
||||
className="group flex items-start gap-3 rounded-lg border border-border/60 bg-card p-4"
|
||||
>
|
||||
<div className="mt-0.5 shrink-0 text-muted-foreground">
|
||||
<Sparkles className="size-4" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium">{prompt.name}</span>
|
||||
<span className="rounded-full border px-2 py-0.5 text-[10px] text-muted-foreground">
|
||||
{prompt.mode}
|
||||
</span>
|
||||
{prompt.author_name && (
|
||||
<span className="text-[11px] text-muted-foreground/60">
|
||||
by {prompt.author_name}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<p
|
||||
className={`mt-1 text-xs text-muted-foreground ${expandedId === prompt.id ? "whitespace-pre-wrap" : "line-clamp-2"}`}
|
||||
>
|
||||
{prompt.prompt}
|
||||
</p>
|
||||
{prompt.prompt.length > 100 && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExpandedId(expandedId === prompt.id ? null : prompt.id)}
|
||||
className="mt-1 text-[11px] text-primary hover:underline cursor-pointer"
|
||||
>
|
||||
{expandedId === prompt.id ? "See less" : "See more"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="shrink-0 gap-1.5"
|
||||
disabled={copyingIds.has(prompt.id)}
|
||||
onClick={() => handleCopy(prompt.id)}
|
||||
>
|
||||
{copyingIds.has(prompt.id) ? (
|
||||
<Spinner className="size-3" />
|
||||
) : (
|
||||
<Copy className="size-3" />
|
||||
)}
|
||||
Add to mine
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,347 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertTriangle, Globe, Lock, PenLine, Plus, Sparkles, Trash2 } from "lucide-react";
|
||||
import { useCallback, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createPromptMutationAtom,
|
||||
deletePromptMutationAtom,
|
||||
updatePromptMutationAtom,
|
||||
} from "@/atoms/prompts/prompts-mutation.atoms";
|
||||
import { promptsAtom } from "@/atoms/prompts/prompts-query.atoms";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import type { PromptRead } from "@/contracts/types/prompts.types";
|
||||
|
||||
interface PromptFormData {
|
||||
name: string;
|
||||
prompt: string;
|
||||
mode: "transform" | "explore";
|
||||
is_public: boolean;
|
||||
}
|
||||
|
||||
const EMPTY_FORM: PromptFormData = { name: "", prompt: "", mode: "transform", is_public: false };
|
||||
|
||||
export function PromptsContent() {
|
||||
const { data: prompts, isLoading, isError } = useAtomValue(promptsAtom);
|
||||
const { mutateAsync: createPrompt } = useAtomValue(createPromptMutationAtom);
|
||||
const { mutateAsync: updatePrompt } = useAtomValue(updatePromptMutationAtom);
|
||||
const { mutateAsync: deletePrompt } = useAtomValue(deletePromptMutationAtom);
|
||||
|
||||
const [showForm, setShowForm] = useState(false);
|
||||
const [editingId, setEditingId] = useState<number | null>(null);
|
||||
const [formData, setFormData] = useState<PromptFormData>(EMPTY_FORM);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [expandedId, setExpandedId] = useState<number | null>(null);
|
||||
const [deleteTarget, setDeleteTarget] = useState<number | null>(null);
|
||||
const [togglingPublicIds, setTogglingPublicIds] = useState<Set<number>>(new Set());
|
||||
|
||||
const handleSave = useCallback(async () => {
|
||||
if (!formData.name.trim() || !formData.prompt.trim()) {
|
||||
toast.error("Name and prompt are required");
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSaving(true);
|
||||
try {
|
||||
if (editingId !== null) {
|
||||
await updatePrompt({ id: editingId, ...formData });
|
||||
} else {
|
||||
await createPrompt(formData);
|
||||
}
|
||||
setShowForm(false);
|
||||
setFormData(EMPTY_FORM);
|
||||
setEditingId(null);
|
||||
} catch {
|
||||
// toast handled by mutation atoms
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}, [formData, editingId, createPrompt, updatePrompt]);
|
||||
|
||||
const handleEdit = useCallback((prompt: PromptRead) => {
|
||||
setFormData({
|
||||
name: prompt.name,
|
||||
prompt: prompt.prompt,
|
||||
mode: prompt.mode as "transform" | "explore",
|
||||
is_public: prompt.is_public,
|
||||
});
|
||||
setEditingId(prompt.id);
|
||||
setShowForm(true);
|
||||
}, []);
|
||||
|
||||
const handleConfirmDelete = useCallback(async () => {
|
||||
if (deleteTarget === null) return;
|
||||
try {
|
||||
await deletePrompt(deleteTarget);
|
||||
} catch {
|
||||
// toast handled by mutation atom
|
||||
} finally {
|
||||
setDeleteTarget(null);
|
||||
}
|
||||
}, [deleteTarget, deletePrompt]);
|
||||
|
||||
const handleTogglePublic = useCallback(
|
||||
async (prompt: PromptRead) => {
|
||||
if (togglingPublicIds.has(prompt.id)) return;
|
||||
setTogglingPublicIds((prev) => new Set(prev).add(prompt.id));
|
||||
try {
|
||||
await updatePrompt({ id: prompt.id, is_public: !prompt.is_public });
|
||||
} catch {
|
||||
// toast handled by mutation atom
|
||||
} finally {
|
||||
setTogglingPublicIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.delete(prompt.id);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
},
|
||||
[updatePrompt, togglingPublicIds]
|
||||
);
|
||||
|
||||
const handleCancel = useCallback(() => {
|
||||
setShowForm(false);
|
||||
setFormData(EMPTY_FORM);
|
||||
setEditingId(null);
|
||||
}, []);
|
||||
|
||||
const list = prompts ?? [];
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner className="size-6" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-destructive/40 p-8 text-center">
|
||||
<AlertTriangle className="mx-auto size-8 text-destructive/60" />
|
||||
<p className="mt-2 text-sm text-destructive">Failed to load prompts</p>
|
||||
<p className="text-xs text-muted-foreground">Please try refreshing the page.</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6 min-w-0 overflow-hidden">
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Create prompt templates triggered with{" "}
|
||||
<kbd className="rounded border bg-muted px-1.5 py-0.5 text-xs font-mono">/</kbd> in the
|
||||
chat composer.
|
||||
</p>
|
||||
{!showForm && (
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setShowForm(true);
|
||||
setEditingId(null);
|
||||
setFormData(EMPTY_FORM);
|
||||
}}
|
||||
className="shrink-0 gap-1.5"
|
||||
>
|
||||
<Plus className="size-3.5" />
|
||||
New
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{showForm && (
|
||||
<div className="rounded-lg border border-border/60 bg-card p-6 space-y-4">
|
||||
<h3 className="text-sm font-semibold tracking-tight">
|
||||
{editingId !== null ? "Edit prompt" : "New prompt"}
|
||||
</h3>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="prompt-name">Name</Label>
|
||||
<Input
|
||||
id="prompt-name"
|
||||
value={formData.name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))}
|
||||
placeholder="e.g. Fix grammar"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="prompt-template">Prompt template</Label>
|
||||
<textarea
|
||||
id="prompt-template"
|
||||
value={formData.prompt}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, prompt: e.target.value }))}
|
||||
placeholder="e.g. Fix the grammar in the following text:\n\n{selection}"
|
||||
rows={4}
|
||||
className="w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm outline-none resize-none focus:ring-1 focus:ring-ring"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Use{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-[11px]">
|
||||
{"{selection}"}
|
||||
</code>{" "}
|
||||
to insert the input text. If omitted, the text is appended automatically.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="prompt-mode">Mode</Label>
|
||||
<select
|
||||
id="prompt-mode"
|
||||
value={formData.mode}
|
||||
onChange={(e) =>
|
||||
setFormData((p) => ({ ...p, mode: e.target.value as "transform" | "explore" }))
|
||||
}
|
||||
className="w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm outline-none focus:ring-1 focus:ring-ring"
|
||||
>
|
||||
<option value="transform">Transform — rewrites or modifies your text</option>
|
||||
<option value="explore">Explore — answers a question about your text</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
<Switch
|
||||
id="prompt-public"
|
||||
checked={formData.is_public}
|
||||
onCheckedChange={(checked) => setFormData((p) => ({ ...p, is_public: checked }))}
|
||||
/>
|
||||
<Label htmlFor="prompt-public" className="text-sm font-normal">
|
||||
Share with community
|
||||
</Label>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-end gap-2 pt-2">
|
||||
<Button variant="ghost" size="sm" onClick={handleCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" onClick={handleSave} disabled={isSaving} className="relative">
|
||||
<span className={isSaving ? "opacity-0" : ""}>
|
||||
{editingId !== null ? "Update" : "Create"}
|
||||
</span>
|
||||
{isSaving && <Spinner className="size-3.5 absolute" />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{list.length === 0 && !showForm && (
|
||||
<div className="rounded-lg border border-dashed border-border/60 p-8 text-center">
|
||||
<Sparkles className="mx-auto size-8 text-muted-foreground/40" />
|
||||
<p className="mt-2 text-sm text-muted-foreground">No prompts yet</p>
|
||||
<p className="text-xs text-muted-foreground/60">
|
||||
Create prompts to quickly transform or explore text with /
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{list.length > 0 && (
|
||||
<div className="space-y-2">
|
||||
{list.map((prompt) => (
|
||||
<div
|
||||
key={prompt.id}
|
||||
className="group flex items-start gap-3 rounded-lg border border-border/60 bg-card p-4"
|
||||
>
|
||||
<div className="mt-0.5 shrink-0 text-muted-foreground">
|
||||
<Sparkles className="size-4" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium">{prompt.name}</span>
|
||||
<span className="rounded-full border px-2 py-0.5 text-[10px] text-muted-foreground">
|
||||
{prompt.mode}
|
||||
</span>
|
||||
{prompt.is_public && (
|
||||
<span className="flex items-center gap-1 rounded-full border border-primary/20 bg-primary/5 px-2 py-0.5 text-[10px] text-primary">
|
||||
<Globe className="size-2.5" />
|
||||
Public
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<p
|
||||
className={`mt-1 text-xs text-muted-foreground ${expandedId === prompt.id ? "whitespace-pre-wrap" : "line-clamp-2"}`}
|
||||
>
|
||||
{prompt.prompt}
|
||||
</p>
|
||||
{prompt.prompt.length > 100 && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExpandedId(expandedId === prompt.id ? null : prompt.id)}
|
||||
className="mt-1 text-[11px] text-primary hover:underline cursor-pointer"
|
||||
>
|
||||
{expandedId === prompt.id ? "See less" : "See more"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<div className="hidden group-hover:flex items-center gap-1 shrink-0">
|
||||
<button
|
||||
type="button"
|
||||
title={prompt.is_public ? "Make private" : "Share with community"}
|
||||
onClick={() => handleTogglePublic(prompt)}
|
||||
disabled={togglingPublicIds.has(prompt.id)}
|
||||
className="flex items-center justify-center size-7 rounded-md text-muted-foreground hover:text-foreground hover:bg-accent transition-colors disabled:opacity-50 disabled:pointer-events-none"
|
||||
>
|
||||
{togglingPublicIds.has(prompt.id) ? (
|
||||
<Spinner className="size-3.5" />
|
||||
) : prompt.is_public ? (
|
||||
<Lock className="size-3.5" />
|
||||
) : (
|
||||
<Globe className="size-3.5" />
|
||||
)}
|
||||
</button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="size-7"
|
||||
onClick={() => handleEdit(prompt)}
|
||||
>
|
||||
<PenLine className="size-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="size-7 text-destructive hover:text-destructive"
|
||||
onClick={() => setDeleteTarget(prompt.id)}
|
||||
>
|
||||
<Trash2 className="size-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AlertDialog
|
||||
open={deleteTarget !== null}
|
||||
onOpenChange={(open) => !open && setDeleteTarget(null)}
|
||||
>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete prompt</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
This action cannot be undone. The prompt will be permanently removed.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction onClick={handleConfirmDelete}>Delete</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
);
|
||||
}
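The form's help text says that `{selection}` marks where the input text is inserted, and that the text is appended automatically when the placeholder is omitted. The substitution itself is not part of this diff, so the following is only a minimal sketch of that rule. The helper name `applyPromptTemplate` and the blank-line separator used when appending are assumptions, not the project's actual implementation.

```typescript
// Hypothetical helper: name, signature, and the "\n\n" separator are assumptions.
const SELECTION_PLACEHOLDER = "{selection}";

function applyPromptTemplate(template: string, selection: string): string {
  if (template.includes(SELECTION_PLACEHOLDER)) {
    // split/join replaces every occurrence and sidesteps the special "$"
    // replacement patterns that String.prototype.replace would interpret.
    return template.split(SELECTION_PLACEHOLDER).join(selection);
  }
  // No placeholder present: append the input text after the template.
  return `${template}\n\n${selection}`;
}

console.log(applyPromptTemplate("Fix the grammar in the following text:\n\n{selection}", "teh cat"));
console.log(applyPromptTemplate("Summarize this", "Some long passage"));
```

Using split/join rather than a regular expression keeps the placeholder literal, so braces in user templates need no escaping.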
|
||||
|
|
@@ -0,0 +1,115 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { Receipt } from "lucide-react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import type { PagePurchase, PagePurchaseStatus } from "@/contracts/types/stripe.types";
|
||||
import { stripeApiService } from "@/lib/apis/stripe-api.service";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const STATUS_STYLES: Record<PagePurchaseStatus, { label: string; className: string }> = {
|
||||
completed: {
|
||||
label: "Completed",
|
||||
className: "bg-emerald-600 text-white border-transparent hover:bg-emerald-600",
|
||||
},
|
||||
pending: {
|
||||
label: "Pending",
|
||||
className: "bg-yellow-600 text-white border-transparent hover:bg-yellow-600",
|
||||
},
|
||||
failed: {
|
||||
label: "Failed",
|
||||
className: "bg-destructive text-white border-transparent hover:bg-destructive",
|
||||
},
|
||||
};
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
return new Date(iso).toLocaleDateString(undefined, {
|
||||
year: "numeric",
|
||||
month: "short",
|
||||
day: "numeric",
|
||||
});
|
||||
}
|
||||
|
||||
function formatAmount(purchase: PagePurchase): string {
|
||||
if (purchase.amount_total == null) return "—";
|
||||
const dollars = purchase.amount_total / 100;
|
||||
const currency = (purchase.currency ?? "usd").toUpperCase();
|
||||
return `$${dollars.toFixed(2)} ${currency}`;
|
||||
}
|
||||
|
||||
export function PurchaseHistoryContent() {
|
||||
const { data, isLoading } = useQuery({
|
||||
queryKey: ["stripe-purchases"],
|
||||
queryFn: () => stripeApiService.getPurchases(),
|
||||
});
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const purchases = data?.purchases ?? [];
|
||||
|
||||
if (purchases.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center gap-2 py-16 text-center">
|
||||
<Receipt className="h-8 w-8 text-muted-foreground" />
|
||||
<p className="text-sm font-medium">No purchases yet</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Your page-pack purchases will appear here after checkout.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Date</TableHead>
|
||||
<TableHead className="text-right">Pages</TableHead>
|
||||
<TableHead className="text-right">Amount</TableHead>
|
||||
<TableHead className="text-center">Status</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{purchases.map((p) => {
|
||||
const style = STATUS_STYLES[p.status];
|
||||
return (
|
||||
<TableRow key={p.id}>
|
||||
<TableCell className="text-sm">{formatDate(p.created_at)}</TableCell>
|
||||
<TableCell className="text-right tabular-nums text-sm">
|
||||
{p.pages_granted.toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell className="text-right tabular-nums text-sm">
|
||||
{formatAmount(p)}
|
||||
</TableCell>
|
||||
<TableCell className="text-center">
|
||||
<Badge className={cn("text-[10px]", style.className)}>{style.label}</Badge>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
<p className="text-center text-xs text-muted-foreground">
|
||||
Showing your {purchases.length} most recent purchase{purchases.length !== 1 ? "s" : ""}.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
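`formatAmount` above always prefixes the amount with `$` and divides by 100, even when `purchase.currency` is not USD. A currency-aware variant could lean on the standard `Intl.NumberFormat` API. This is a sketch only: the helper name `formatMinorUnits`, the `"n/a"` fallback, and the default `en-US` locale are assumptions, and the minor-unit exponent that `Intl` reports may differ from the payment provider's own list of zero-decimal currencies.

```typescript
// Hypothetical helper: name, fallback text, and default locale are assumptions.
function formatMinorUnits(
  amountMinor: number | null | undefined,
  currency: string | null | undefined,
  locale: string = "en-US"
): string {
  if (amountMinor == null) return "n/a";
  const code = (currency ?? "usd").toUpperCase();
  const formatter = new Intl.NumberFormat(locale, { style: "currency", currency: code });
  // For style "currency", the resolved fraction digits follow the currency's
  // minor unit (2 for USD, 0 for JPY), which gives the divisor to use.
  const fractionDigits = formatter.resolvedOptions().maximumFractionDigits ?? 2;
  return formatter.format(amountMinor / Math.pow(10, fractionDigits));
}

console.log(formatMinorUnits(1250, "usd"));
console.log(formatMinorUnits(1250, "eur"));
console.log(formatMinorUnits(500, "jpy"));
```

Passing `undefined` as the locale would follow the runtime's default, matching how `formatDate` calls `toLocaleDateString`.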
|
||||
|
|
@@ -1,8 +1,10 @@
 "use client";

 import { useEffect, useState } from "react";
+import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
 import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
 import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
+import { queryClient } from "@/lib/query-client/client";

 interface DashboardLayoutProps {
 	children: React.ReactNode;
@@ -22,6 +24,7 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
 			redirectToLogin();
 			return;
 		}
+		queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
 		setIsCheckingAuth(false);
 	}, []);

@@ -3,7 +3,7 @@
 import { useAtomValue } from "jotai";
 import { AlertCircle, Plus, Search } from "lucide-react";
 import { motion } from "motion/react";
-import { useRouter } from "next/navigation";
+import { useRouter, useSearchParams } from "next/navigation";
 import { useTranslations } from "next-intl";
 import { useEffect, useState } from "react";
 import { searchSpacesAtom } from "@/atoms/search-spaces/search-space-query.atoms";
@@ -89,6 +89,7 @@ function EmptyState({ onCreateClick }: { onCreateClick: () => void }) {

 export default function DashboardPage() {
 	const router = useRouter();
+	const searchParams = useSearchParams();
 	const [showCreateDialog, setShowCreateDialog] = useState(false);

 	const t = useTranslations("dashboard");
@@ -98,9 +99,11 @@ export default function DashboardPage() {
 		if (isLoading) return;

 		if (searchSpaces.length > 0) {
-			router.replace(`/dashboard/${searchSpaces[0].id}/new-chat`);
+			const params = searchParams.toString();
+			const query = params ? `?${params}` : "";
+			router.replace(`/dashboard/${searchSpaces[0].id}/new-chat${query}`);
 		}
-	}, [isLoading, searchSpaces, router]);
+	}, [isLoading, searchSpaces, router, searchParams]);

 	// Show loading while fetching or while we have spaces and are about to redirect
 	const shouldShowLoading = isLoading || searchSpaces.length > 0;

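The `DashboardPage` change above carries the incoming query string over to the `new-chat` redirect instead of dropping it. The same rule can be sketched as a standalone function, which is easier to unit test than the effect. The name `buildNewChatPath` is hypothetical, and the standard `URLSearchParams` class stands in here for the read-only object returned by Next.js `useSearchParams`.

```typescript
// Hypothetical helper mirroring the redirect logic in the hunk above.
function buildNewChatPath(searchSpaceId: number, searchParams: URLSearchParams): string {
  const params = searchParams.toString();
  // Only emit "?" when there is something to forward.
  const query = params ? `?${params}` : "";
  return `/dashboard/${searchSpaceId}/new-chat${query}`;
}

console.log(buildNewChatPath(7, new URLSearchParams("prompt=hello&mode=explore")));
console.log(buildNewChatPath(7, new URLSearchParams()));
```

The parameter names in the usage lines are made up for illustration; the diff does not say which query parameters the redirect is meant to preserve.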
@@ -133,6 +133,12 @@ export default function sitemap(): MetadataRoute.Sitemap {
 			changeFrequency: "daily",
 			priority: 0.8,
 		},
+		{
+			url: "https://www.surfsense.com/docs/connectors/dropbox",
+			lastModified,
+			changeFrequency: "daily",
+			priority: 0.8,
+		},
 		{
 			url: "https://www.surfsense.com/docs/connectors/elasticsearch",
 			lastModified,
@@ -181,6 +187,12 @@ export default function sitemap(): MetadataRoute.Sitemap {
 			changeFrequency: "daily",
 			priority: 0.8,
 		},
+		{
+			url: "https://www.surfsense.com/docs/connectors/microsoft-onedrive",
+			lastModified,
+			changeFrequency: "daily",
+			priority: 0.8,
+		},
 		{
 			url: "https://www.surfsense.com/docs/connectors/microsoft-teams",
 			lastModified,

@@ -2,6 +2,8 @@ import { atomWithMutation } from "jotai-tanstack-query";
 import { toast } from "sonner";
 import type {
 	CreateImageGenConfigRequest,
+	CreateImageGenConfigResponse,
+	DeleteImageGenConfigResponse,
 	GetImageGenConfigsResponse,
 	UpdateImageGenConfigRequest,
 	UpdateImageGenConfigResponse,
@@ -23,14 +25,14 @@ export const createImageGenConfigMutationAtom = atomWithMutation((get) => {
 		mutationFn: async (request: CreateImageGenConfigRequest) => {
 			return imageGenConfigApiService.createConfig(request);
 		},
-		onSuccess: () => {
-			toast.success("Image model configuration created");
+		onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => {
+			toast.success(`${request.name} created`);
 			queryClient.invalidateQueries({
 				queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
 			});
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to create image model configuration");
+			toast.error(error.message || "Failed to create image model");
 		},
 	};
 });
@@ -48,7 +50,7 @@ export const updateImageGenConfigMutationAtom = atomWithMutation((get) => {
 			return imageGenConfigApiService.updateConfig(request);
 		},
 		onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => {
-			toast.success("Image model configuration updated");
+			toast.success(`${request.data.name ?? "Configuration"} updated`);
 			queryClient.invalidateQueries({
 				queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
 			});
@@ -57,7 +59,7 @@ export const updateImageGenConfigMutationAtom = atomWithMutation((get) => {
 			});
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to update image model configuration");
+			toast.error(error.message || "Failed to update image model");
 		},
 	};
 });
@@ -71,21 +73,21 @@ export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => {
 	return {
 		mutationKey: ["image-gen-configs", "delete"],
 		enabled: !!searchSpaceId,
-		mutationFn: async (id: number) => {
-			return imageGenConfigApiService.deleteConfig(id);
+		mutationFn: async (request: { id: number; name: string }) => {
+			return imageGenConfigApiService.deleteConfig(request.id);
 		},
-		onSuccess: (_, id: number) => {
-			toast.success("Image model configuration deleted");
+		onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => {
+			toast.success(`${request.name} deleted`);
 			queryClient.setQueryData(
 				cacheKeys.imageGenConfigs.all(Number(searchSpaceId)),
 				(oldData: GetImageGenConfigsResponse | undefined) => {
 					if (!oldData) return oldData;
-					return oldData.filter((config) => config.id !== id);
+					return oldData.filter((config) => config.id !== request.id);
 				}
 			);
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to delete image model configuration");
+			toast.error(error.message || "Failed to delete image model");
 		},
 	};
 });

@@ -2,7 +2,9 @@ import { atomWithMutation } from "jotai-tanstack-query";
 import { toast } from "sonner";
 import type {
 	CreateNewLLMConfigRequest,
+	CreateNewLLMConfigResponse,
 	DeleteNewLLMConfigRequest,
+	DeleteNewLLMConfigResponse,
 	GetNewLLMConfigsResponse,
 	UpdateLLMPreferencesRequest,
 	UpdateNewLLMConfigRequest,
@@ -25,14 +27,14 @@ export const createNewLLMConfigMutationAtom = atomWithMutation((get) => {
 		mutationFn: async (request: CreateNewLLMConfigRequest) => {
 			return newLLMConfigApiService.createConfig(request);
 		},
-		onSuccess: () => {
-			toast.success("Configuration created successfully");
+		onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => {
+			toast.success(`${request.name} created`);
 			queryClient.invalidateQueries({
 				queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
 			});
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to create configuration");
+			toast.error(error.message || "Failed to create LLM model");
 		},
 	};
 });
@@ -50,7 +52,7 @@ export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => {
 			return newLLMConfigApiService.updateConfig(request);
 		},
 		onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => {
-			toast.success("Configuration updated successfully");
+			toast.success(`${request.data.name ?? "Configuration"} updated`);
 			queryClient.invalidateQueries({
 				queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
 			});
@@ -59,7 +61,7 @@ export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => {
 			});
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to update configuration");
+			toast.error(error.message || "Failed to update");
 		},
 	};
 });
@@ -73,11 +75,14 @@ export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => {
 	return {
 		mutationKey: ["new-llm-configs", "delete"],
 		enabled: !!searchSpaceId,
-		mutationFn: async (request: DeleteNewLLMConfigRequest) => {
-			return newLLMConfigApiService.deleteConfig(request);
+		mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => {
+			return newLLMConfigApiService.deleteConfig({ id: request.id });
 		},
-		onSuccess: (_, request: DeleteNewLLMConfigRequest) => {
-			toast.success("Configuration deleted successfully");
+		onSuccess: (
+			_: DeleteNewLLMConfigResponse,
+			request: DeleteNewLLMConfigRequest & { name: string }
+		) => {
+			toast.success(`${request.name} deleted`);
 			queryClient.setQueryData(
 				cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
 				(oldData: GetNewLLMConfigsResponse | undefined) => {
@@ -87,7 +92,7 @@ export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => {
 			);
 		},
 		onError: (error: Error) => {
-			toast.error(error.message || "Failed to delete configuration");
+			toast.error(error.message || "Failed to delete");
 		},
 	};
 });

Some files were not shown because too many files have changed in this diff.