diff --git a/.gitignore b/.gitignore index a99954efe..d086673db 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,5 @@ surfsense_web/test-results/ surfsense_web/blob-report/ content_research/ +automation-design-plan.md +automation-frontend-builder-plan.md diff --git a/README.es.md b/README.es.md index dea86a793..ea7623617 100644 --- a/README.es.md +++ b/README.es.md @@ -41,6 +41,7 @@ NotebookLM es una de las mejores y más útiles plataformas de IA que existen, p - **Sin Dependencia de Proveedores** - Configura cualquier modelo LLM, de imagen, TTS y STT. - **25+ Fuentes de Datos Externas** - Agrega tus fuentes desde Google Drive, OneDrive, Dropbox, Notion y muchos otros servicios externos. - **Soporte Multijugador en Tiempo Real** - Trabaja fácilmente con los miembros de tu equipo en un notebook compartido. +- **Automatizaciones y Agentes de IA** - Ejecuta agentes de IA según una programación o actívalos en el momento en que un documento llega a una carpeta, y luego escribe los resultados de vuelta en Notion, Slack, Linear y Drive. Crea automatizaciones sin código solo describiéndolas en el chat. - **Aplicación de Escritorio** - Obtén asistencia de IA en cualquier aplicación con Quick Assist, General Assist, Screenshot Assist y sincronización de carpetas locales. ...y más por venir. @@ -76,48 +77,118 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7 4. Una vez que todo esté indexado, pregunta lo que quieras (Casos de uso): - - Aplicación de Escritorio — General Assist + **Aplicación de Escritorio** (extras nativos, además de todo lo de abajo, no un conjunto aparte) + + - General Assist: abre SurfSense al instante desde cualquier aplicación con un atajo global.

General Assist

- - Aplicación de Escritorio — Quick Assist + - Quick Assist: selecciona texto en cualquier lugar y pide a la IA que lo explique, reescriba o actúe sobre él.

Quick Assist

- - Aplicación de Escritorio — Screenshot Assist + - Screenshot Assist: captura cualquier región de tu pantalla y pregunta a la IA sobre lo que contiene.

Screenshot Assist

- - Aplicación de Escritorio — Watch Local Folder + - Watch Local Folder: sincroniza automáticamente una carpeta local con tu base de conocimiento. Ideal para bóvedas de Obsidian.

Watch Local Folder

- - Generación de videos + **Estudio de Entregables** -

Generación de Videos

+ - AI Report Generator: genera informes de investigación con citas y expórtalos a PDF, DOCX, HTML, LaTeX, EPUB, ODT o texto plano. - - Búsqueda básica y citaciones +

AI Report Generator

-

Búsqueda y Citación

+ - AI Podcast Generator: convierte cualquier documento o carpeta en un pódcast de IA con dos presentadores en menos de 20 segundos. - - QNA con mención de documentos +

AI Podcast Generator

-

QNA con Mención de Documentos

-

QNA con Mención de Documentos

+ - AI Presentation & Video Maker: crea presentaciones editables y videos narrados a partir de tus fuentes. - - Generación de informes y exportaciones (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto plano) +

AI Presentation and Video Maker

-

Generación de Informes

+ - AI Image Generator: genera imágenes de alta calidad directamente desde tus chats y documentos. - - Generación de podcasts +

AI Image Generator

-

Generación de Podcasts

+ - AI Resume Builder: adapta tu currículum existente a cualquier descripción de empleo y supera el ATS. + Prueba indicaciones como estas: - - Generación de imágenes + - "Adapta mi currículum a esta descripción de empleo para superar el ATS y conseguir una entrevista." + - "Optimiza mi currículum para ATS haciendo coincidir las palabras clave de esta oferta." + - "Reescribe los puntos de mi currículum para resaltar las habilidades que pide este puesto." + - "Compara mi currículum con esta descripción de empleo y enumera las carencias a corregir." + - "Escribe una carta de presentación a juego con mi currículum y esta descripción de empleo." -

Generación de Imágenes

+ **Búsqueda y Chat** - - Y más próximamente. + - Chat With Your PDFs & Docs: haz preguntas sobre todos tus archivos y obtén respuestas con citas en línea. + +

Chat With Your PDFs and Docs

+ + - AI Search With Citations: búsqueda híbrida semántica y por palabras clave en toda tu base de conocimiento. + +

AI Search With Citations

+ + - Collaborative AI Chat: trabaja en conversaciones de IA con tu equipo en tiempo real. + +

Collaborative AI Chat

+ + - Comments & Mentions: comenta y menciona a tus compañeros en cualquier mensaje de IA. + +

Comments and Mentions

+ + **Conectores e Integraciones** + + - Connect & Sync Your Tools: sincroniza Notion, Slack, Google Drive, Gmail, GitHub, Linear y más de 25 fuentes en un único corpus consultable. + +

Connect and Sync Your Tools

+ + - Chat With Uploaded Files: sube PDFs, documentos de Office, imágenes y audio. Consultables al instante. + +

Chat With Uploaded Files

+ + - Connector Write-Back: deja que el agente publique los resultados de vuelta en Notion, Slack, Linear y Drive. + Prueba indicaciones como estas: + + - "Publica este resumen de investigación en mi espacio de Notion." + - "Envía estos elementos de acción de la reunión a nuestro canal de Slack." + - "Crea un ticket de Jira a partir de este informe de error." + - "Abre una incidencia en Linear a partir de esta solicitud de función." + - "Guarda este informe generado en Google Drive como un documento." + + - Obsidian & Knowledge Base Sync: mantén tu bóveda de Obsidian y tu base de conocimiento personal sincronizadas. + + **Automatizaciones** + + - Scheduled AI Workflows: ejecuta un agente según una programación: resúmenes diarios, boletines semanales, informes recurrentes. + Prueba indicaciones como estas: + + - "Envíame cada mañana un resumen diario de los nuevos documentos en mi base de conocimiento." + - "Genera un informe de estado semanal a partir de mi Slack y Gmail cada viernes." + - "Ejecuta un informe mensual de análisis de la competencia y guárdalo en mi espacio de trabajo." + - "Resume mi actividad de GitHub y Linear en una actualización diaria de standup." + - "Crea un informe de investigación semanal recurrente sobre los temas que sigo." + + - Event-Triggered Automations: lanza un agente en el momento en que un documento llega a una carpeta y publica el resultado en tus herramientas. + Prueba indicaciones como estas: + + - "Cuando llegue un PDF a mi carpeta de Investigación, genera un resumen de IA con citas." + - "Cuando se añadan nuevas notas de reunión, conviértelas en actas con elementos de acción." + - "Cuando se suba una factura, extrae el proveedor, el total y la fecha de vencimiento en una tabla." + - "Cuando entre un contrato en mi carpeta Legal, señala los términos clave y las fechas de renovación." + - "Cuando se añada un currículum a Candidatos, evalúalo frente a la descripción del empleo." 
+ + - Chat-Built Automations: describe una automatización en lenguaje sencillo y SurfSense la crea por ti. + Prueba indicaciones como estas: + + - "Crea un agente de IA que me envíe cada mañana un resumen de las nuevas páginas de Notion." + - "Crea una automatización sin código que publique un resumen de investigación semanal en Slack." + - "Configura un tomador de notas con IA que convierta las nuevas notas de reunión en actas." + - "Crea un flujo que extraiga los elementos de acción de las notas de reunión y asigne responsables." + - "Automatiza un resumen diario por correo a partir de mi Gmail y Google Drive." ### Auto-Hospedado @@ -199,6 +270,7 @@ Todas las funciones operan contra tu espacio de búsqueda elegido, por lo que tu | **Generación de Videos** | Resúmenes en video cinemáticos vía Veo 3 (solo Ultra) | Disponible (NotebookLM es mejor aquí, mejorando activamente) | | **Generación de Presentaciones** | Diapositivas más atractivas pero no editables | Crea presentaciones editables basadas en diapositivas | | **Generación de Podcasts** | Resúmenes de audio con hosts e idiomas personalizables | Disponible con múltiples proveedores TTS (NotebookLM es mejor aquí, mejorando activamente) | +| **Automatizaciones y Agentes de IA** | No | Flujos de trabajo de IA programados, disparadores por eventos en documentos nuevos y automatizaciones sin código creadas por chat con escritura de vuelta a Notion, Slack, Linear y Jira | | **Aplicación de Escritorio** | No | Aplicación nativa con General Assist, Quick Assist, Screenshot Assist y sincronización de carpetas locales | | **Extensión de Navegador** | No | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación | diff --git a/README.hi.md b/README.hi.md index 43e24c3ee..10b246385 100644 --- a/README.hi.md +++ b/README.hi.md @@ -41,6 +41,7 @@ NotebookLM वहाँ उपलब्ध सबसे अच्छे और - **कोई विक्रेता लॉक-इन नहीं** - किसी भी LLM, इमेज, TTS और STT मॉडल को कॉन्फ़िगर करें। - 
**25+ बाहरी डेटा स्रोत** - Google Drive, OneDrive, Dropbox, Notion और कई अन्य बाहरी सेवाओं से अपने स्रोत जोड़ें। - **रीयल-टाइम मल्टीप्लेयर सपोर्ट** - एक साझा notebook में अपनी टीम के सदस्यों के साथ आसानी से काम करें। +- **AI ऑटोमेशन और एजेंट** - AI एजेंट को शेड्यूल पर चलाएं या जैसे ही कोई दस्तावेज़ किसी फ़ोल्डर में आए उसे ट्रिगर करें, फिर परिणाम वापस Notion, Slack, Linear और Drive में लिखें। चैट में बस वर्णन करके बिना-कोड ऑटोमेशन बनाएं। - **डेस्कटॉप ऐप** - Quick Assist, General Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ किसी भी एप्लिकेशन में AI सहायता प्राप्त करें। ...और भी बहुत कुछ आने वाला है। @@ -76,48 +77,118 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7 4. सब कुछ इंडेक्स हो जाने के बाद, कुछ भी पूछें (उपयोग के मामले): - - डेस्कटॉप ऐप — General Assist + **डेस्कटॉप ऐप** (नीचे दी गई सभी सुविधाओं के अलावा नेटिव एक्स्ट्रा, कोई अलग सेट नहीं) + + - General Assist: किसी भी ऐप्लिकेशन से ग्लोबल शॉर्टकट के ज़रिए SurfSense तुरंत खोलें।

General Assist

- - डेस्कटॉप ऐप — Quick Assist + - Quick Assist: कहीं भी टेक्स्ट चुनें और AI से उसे समझाने, दोबारा लिखने या उस पर कार्रवाई करने को कहें।

Quick Assist

- - डेस्कटॉप ऐप — Screenshot Assist + - Screenshot Assist: अपनी स्क्रीन का कोई भी हिस्सा कैप्चर करें और AI से उसमें मौजूद चीज़ों के बारे में पूछें।

Screenshot Assist

- - डेस्कटॉप ऐप — Watch Local Folder + - Watch Local Folder: किसी लोकल फ़ोल्डर को अपने नॉलेज बेस के साथ अपने-आप सिंक करें। Obsidian vaults के लिए बढ़िया।

Watch Local Folder

- - वीडियो जनरेशन + **डिलीवरेबल स्टूडियो** -

वीडियो जनरेशन

+ - AI Report Generator: उद्धरण सहित रिसर्च रिपोर्ट बनाएं और PDF, DOCX, HTML, LaTeX, EPUB, ODT या सादे टेक्स्ट में एक्सपोर्ट करें। - - बेसिक सर्च और उद्धरण +

AI Report Generator

-

सर्च और उद्धरण

+ - AI Podcast Generator: किसी भी दस्तावेज़ या फ़ोल्डर को 20 सेकंड से भी कम में दो-होस्ट वाले AI पॉडकास्ट में बदलें। - - दस्तावेज़ मेंशन QNA +

AI Podcast Generator

-

दस्तावेज़ मेंशन QNA

-

दस्तावेज़ मेंशन QNA

+ - AI Presentation & Video Maker: अपने स्रोतों से एडिट करने योग्य स्लाइड डेक और नैरेटेड वीडियो बनाएं। - - रिपोर्ट जनरेशन और एक्सपोर्ट (PDF, DOCX, HTML, LaTeX, EPUB, ODT, सादा टेक्स्ट) +

AI Presentation and Video Maker

-

रिपोर्ट जनरेशन

+ - AI Image Generator: अपनी चैट और दस्तावेज़ों से सीधे उच्च-गुणवत्ता वाली इमेज बनाएं। - - पॉडकास्ट जनरेशन +

AI Image Generator

-

पॉडकास्ट जनरेशन

+ - AI Resume Builder: अपने मौजूदा रिज़्यूमे को किसी भी जॉब डिस्क्रिप्शन के अनुसार ढालें और ATS को पार करें। + इस तरह के प्रॉम्प्ट आज़माएं: - - इमेज जनरेशन + - "मेरे रिज़्यूमे को इस जॉब डिस्क्रिप्शन के अनुसार ढालें ताकि वह ATS पार करे और इंटरव्यू दिलाए।" + - "इस जॉब पोस्टिंग के कीवर्ड्स से मिलान करके मेरे रिज़्यूमे को ATS के लिए ऑप्टिमाइज़ करें।" + - "इस भूमिका के लिए ज़रूरी स्किल्स को उभारने के लिए मेरे रिज़्यूमे के बुलेट पॉइंट फिर से लिखें।" + - "मेरे रिज़्यूमे की तुलना इस जॉब डिस्क्रिप्शन से करें और सुधारने योग्य कमियों की सूची दें।" + - "मेरे रिज़्यूमे और इस जॉब डिस्क्रिप्शन से मेल खाता एक कवर लेटर लिखें।" -

इमेज जनरेशन

+ **सर्च और चैट** - - और भी बहुत कुछ जल्द आ रहा है। + - Chat With Your PDFs & Docs: अपनी सभी फ़ाइलों पर सवाल पूछें और इनलाइन उद्धरणों के साथ जवाब पाएं। + +

Chat With Your PDFs and Docs

+ + - AI Search With Citations: अपने पूरे नॉलेज बेस में हाइब्रिड सेमांटिक और कीवर्ड सर्च। + +

AI Search With Citations

+ + - Collaborative AI Chat: अपनी टीम के साथ रियल टाइम में AI बातचीत पर काम करें। + +

Collaborative AI Chat

+ + - Comments & Mentions: किसी भी AI संदेश पर टिप्पणी करें और टीम के साथियों को टैग करें। + +

Comments and Mentions

+ + **कनेक्टर्स और इंटीग्रेशन** + + - Connect & Sync Your Tools: Notion, Slack, Google Drive, Gmail, GitHub, Linear और 25+ स्रोतों को एक खोजने योग्य कॉर्पस में सिंक करें। + +

Connect and Sync Your Tools

+ + - Chat With Uploaded Files: PDF, Office दस्तावेज़, इमेज और ऑडियो अपलोड करें। तुरंत खोजने योग्य। + +

Chat With Uploaded Files

+ + - Connector Write-Back: एजेंट को परिणाम वापस Notion, Slack, Linear और Drive में पोस्ट करने दें। + इस तरह के प्रॉम्प्ट आज़माएं: + + - "इस रिसर्च सारांश को मेरे Notion वर्कस्पेस में पोस्ट करें।" + - "इन मीटिंग एक्शन आइटम्स को हमारे टीम Slack चैनल पर भेजें।" + - "इस बग रिपोर्ट से एक Jira टिकट बनाएं।" + - "इस फ़ीचर अनुरोध से Linear में एक इश्यू खोलें।" + - "इस जनरेट की गई रिपोर्ट को Google Drive में एक डॉक के रूप में सेव करें।" + + - Obsidian & Knowledge Base Sync: अपने Obsidian vault और व्यक्तिगत नॉलेज बेस को सिंक रखें। + + **ऑटोमेशन** + + - Scheduled AI Workflows: किसी एजेंट को शेड्यूल पर चलाएं: रोज़ाना ब्रीफ़, साप्ताहिक डाइजेस्ट, आवर्ती रिपोर्ट। + इस तरह के प्रॉम्प्ट आज़माएं: + + - "हर सुबह मेरे नॉलेज बेस में जुड़े नए दस्तावेज़ों का रोज़ाना ब्रीफ़ मुझे ईमेल करें।" + - "हर शुक्रवार मेरे Slack और Gmail से एक साप्ताहिक स्टेटस रिपोर्ट बनाएं।" + - "एक मासिक प्रतिस्पर्धी विश्लेषण रिपोर्ट चलाएं और उसे मेरे वर्कस्पेस में सेव करें।" + - "मेरी GitHub और Linear गतिविधि को एक रोज़ाना standup अपडेट में सारांशित करें।" + - "मैं जिन विषयों को ट्रैक करता हूं उन पर एक आवर्ती साप्ताहिक रिसर्च रिपोर्ट बनाएं।" + + - Event-Triggered Automations: जैसे ही कोई दस्तावेज़ किसी फ़ोल्डर में आता है, एजेंट को चलाएं और परिणाम अपने टूल में पोस्ट करें। + इस तरह के प्रॉम्प्ट आज़माएं: + + - "जब मेरे Research फ़ोल्डर में कोई PDF आए, तो उद्धरण सहित एक AI सारांश बनाएं।" + - "जब नई मीटिंग नोट्स जुड़ें, तो उन्हें एक्शन आइटम्स के साथ मीटिंग मिनट्स में बदलें।" + - "जब कोई इनवॉइस अपलोड हो, तो विक्रेता, कुल राशि और देय तिथि को एक तालिका में निकालें।" + - "जब मेरे Legal फ़ोल्डर में कोई अनुबंध आए, तो मुख्य शर्तों और नवीनीकरण तिथियों को चिह्नित करें।" + - "जब Candidates में कोई रिज़्यूमे जुड़े, तो उसे जॉब डिस्क्रिप्शन के विरुद्ध स्क्रीन करें।" + + - Chat-Built Automations: सरल भाषा में किसी ऑटोमेशन का वर्णन करें और SurfSense उसे आपके लिए बना देगा। + इस तरह के प्रॉम्प्ट आज़माएं: + + - "एक AI एजेंट बनाएं जो हर सुबह नए Notion पेजों का सारांश मुझे ईमेल करे।" + - "एक नो-कोड ऑटोमेशन बनाएं जो हर सप्ताह एक रिसर्च 
डाइजेस्ट Slack पर पोस्ट करे।" + - "एक AI नोट-टेकर सेट करें जो नई मीटिंग नोट्स को मिनट्स में बदल दे।" + - "एक वर्कफ़्लो बनाएं जो मीटिंग नोट्स से एक्शन आइटम्स निकाले और ज़िम्मेदार सौंपे।" + - "मेरे Gmail और Google Drive से एक रोज़ाना ईमेल ब्रीफ़ को ऑटोमेट करें।" ### सेल्फ-होस्टेड @@ -199,6 +270,7 @@ SurfSense एक डेस्कटॉप ऐप भी प्रदान क | **वीडियो जनरेशन** | Veo 3 के माध्यम से सिनेमैटिक वीडियो ओवरव्यू (केवल Ultra) | उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) | | **प्रेजेंटेशन जनरेशन** | बेहतर दिखने वाली स्लाइड्स लेकिन संपादन योग्य नहीं | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं | | **पॉडकास्ट जनरेशन** | कस्टमाइज़ेबल होस्ट और भाषाओं के साथ ऑडियो ओवरव्यू | कई TTS प्रदाताओं के साथ उपलब्ध (NotebookLM यहाँ बेहतर है, सक्रिय रूप से सुधार हो रहा है) | +| **AI ऑटोमेशन और एजेंट** | नहीं | शेड्यूल किए गए AI वर्कफ़्लो, नए दस्तावेज़ों पर इवेंट ट्रिगर, और चैट से बने बिना-कोड ऑटोमेशन, Notion, Slack, Linear और Jira में कनेक्टर राइट-बैक के साथ | | **डेस्कटॉप ऐप** | नहीं | General Assist, Quick Assist, Screenshot Assist और लोकल फ़ोल्डर सिंक के साथ नेटिव ऐप | | **ब्राउज़र एक्सटेंशन** | नहीं | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित | diff --git a/README.md b/README.md index ab9f9e221..a75122892 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ NotebookLM is one of the best and most useful AI platforms out there, but once y - **25+ External Data Sources** - Add your sources from Google Drive, OneDrive, Dropbox, Notion, and many other external services. - **Real-Time Multiplayer Support** - Work easily with your team members in a shared notebook. - **AI File Sorting** - Automatically organize your documents into a smart folder hierarchy using AI-powered categorization by source, date, and topic. +- **AI Automations & Agents** - Run AI agents on a schedule or trigger them the moment a document lands in a folder, then write results back to Notion, Slack, Linear, and Drive. 
Build no-code automations just by describing them in chat. - **Desktop App** - Get AI assistance in any application with Quick Assist, General Assist, Screenshot Assist, and local folder sync. ...and more to come. @@ -77,48 +78,118 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7 4. Once everything is indexed, Ask Away (Use Cases): - - Desktop App — General Assist + **Desktop App** (native extras on top of everything below, not a separate feature set) + + - General Assist: launch SurfSense instantly from any application with a global shortcut.

General Assist

- - Desktop App — Quick Assist + - Quick Assist: select text anywhere, then ask AI to explain, rewrite, or act on it.

Quick Assist

- - Desktop App — Screenshot Assist + - Screenshot Assist: capture any region of your screen and ask AI about what's in it.

Screenshot Assist

- - Desktop App — Watch Local Folder + - Watch Local Folder: auto-sync a local folder to your knowledge base. Great for Obsidian vaults.

Watch Local Folder

- - Video Generation + **Deliverable Studio** -

Video Generation

+ - AI Report Generator: generate cited research reports and export to PDF, DOCX, HTML, LaTeX, EPUB, ODT, or plain text. - - Basic search and citation +

AI Report Generator

-

Search and Citation

+ - AI Podcast Generator: turn any document or folder into a two-host AI podcast in under 20 seconds. - - Document Mention QNA +

AI Podcast Generator

-

Document Mention QNA

-

Document Mention QNA

+ - AI Presentation & Video Maker: create editable slide decks and narrated video overviews from your sources. - - Report Generations and Exports (PDF, DOCX, HTML, LaTeX, EPUB, ODT, Plain Text) +

AI Presentation and Video Maker

-

Report Generation

+ - AI Image Generator: generate high-quality images straight from your chats and documents. - - Podcast Generations +

AI Image Generator

-

Podcast Generation

+ - AI Resume Builder: tailor your existing resume to any job description and beat the ATS. + Try prompts like these: - - Image Generations + - "Tailor my resume to this job description so it gets past ATS and lands an interview." + - "Optimize my resume for ATS by matching the keywords in this job posting." + - "Rewrite my resume bullet points to highlight the skills this role is asking for." + - "Compare my resume against this job description and list the gaps to fix." + - "Write a matching cover letter from my resume and this job description." -

Image Generation

+ **Search & Chat** - - And more coming soon. + - Chat With Your PDFs & Docs: ask questions across all your files and get answers with inline citations. + +

Chat With Your PDFs and Docs

+ + - AI Search With Citations: hybrid semantic and keyword search across your entire knowledge base. + +

AI Search With Citations

+ + - Collaborative AI Chat: work on AI conversations with your team in real time. + +

Collaborative AI Chat

+ + - Comments & Mentions: comment and tag teammates on any AI message. + +

Comments and Mentions

+ + **Connectors & Integrations** + + - Connect & Sync Your Tools: sync Notion, Slack, Google Drive, Gmail, GitHub, Linear and 25+ sources into one searchable corpus. + +

Connect and Sync Your Tools

+ + - Chat With Uploaded Files: drop in PDFs, Office docs, images and audio. Instantly searchable. + +

Chat With Uploaded Files

+ + - Connector Write-Back: let the agent post results back to Notion, Slack, Linear and Drive. + Try prompts like these: + + - "Post this research summary to my Notion workspace." + - "Send these meeting action items to our team Slack channel." + - "Create a Jira ticket from this bug report." + - "Open a Linear issue from this feature request." + - "Save this generated report to Google Drive as a doc." + + - Obsidian & Knowledge Base Sync: keep your Obsidian vault and personal knowledge base in sync. + + **Automations** + + - Scheduled AI Workflows: run an agent on a schedule: daily briefs, weekly digests, recurring reports. + Try prompts like these: + + - "Email me a daily brief of new documents in my knowledge base every morning." + - "Generate a weekly status report from my Slack and Gmail every Friday." + - "Run a monthly competitor analysis report and save it to my workspace." + - "Summarize my GitHub and Linear activity into a daily standup update." + - "Create a recurring weekly research report on the topics I track." + + - Event-Triggered Automations: fire an agent the moment a document lands in a folder, then post the result to your tools. + Try prompts like these: + + - "When a PDF lands in my Research folder, generate a cited AI summary." + - "When new meeting notes are added, turn them into meeting minutes with action items." + - "When an invoice is uploaded, extract the vendor, total, and due date into a table." + - "When a contract enters my Legal folder, flag key terms and renewal dates." + - "When a resume is added to Candidates, screen it against the job description." + + - Chat-Built Automations: describe an automation in plain English and SurfSense builds it for you. + Try prompts like these: + + - "Build an AI agent that emails me a summary of new Notion pages each morning." + - "Create a no-code automation that posts a weekly research digest to Slack." + - "Set up an AI note taker that turns new meeting notes into minutes." 
+ - "Make a workflow that extracts action items from meeting notes and assigns owners." + - "Automate a daily email brief from my Gmail and Google Drive." ### Self Hosted @@ -201,6 +272,7 @@ All features operate against your chosen search space, so your answers are alway | **Presentation Generation** | Better looking slides but not editable | Create editable, slide-based presentations | | **Podcast Generation** | Audio Overviews with customizable hosts and languages | Available with multiple TTS providers (NotebookLM is better here, actively improving) | | **AI File Sorting** | No | LLM-powered auto-categorization into source, date, category, and subcategory folders | +| **AI Automations & Agents** | No | Scheduled AI workflows, event triggers on new documents, and chat-built no-code automations with connector write-back to Notion, Slack, Linear, Jira & Drive | | **Desktop App** | No | Native app with General Assist, Quick Assist, Screenshot Assist, and local folder sync | | **Browser Extension** | No | Cross-browser extension to save any webpage, including auth-protected pages | diff --git a/README.pt-BR.md b/README.pt-BR.md index fcb004cd6..db77e5132 100644 --- a/README.pt-BR.md +++ b/README.pt-BR.md @@ -41,6 +41,7 @@ O NotebookLM é uma das melhores e mais úteis plataformas de IA disponíveis, m - **Sem Dependência de Fornecedor** - Configure qualquer modelo LLM, de imagem, TTS e STT. - **25+ Fontes de Dados Externas** - Adicione suas fontes do Google Drive, OneDrive, Dropbox, Notion e muitos outros serviços externos. - **Suporte Multiplayer em Tempo Real** - Trabalhe facilmente com os membros da sua equipe em um notebook compartilhado. +- **Automações e Agentes de IA** - Execute agentes de IA em uma programação ou dispare-os no momento em que um documento chega a uma pasta, e escreva os resultados de volta no Notion, Slack, Linear e Drive. Crie automações sem código apenas descrevendo-as no chat. 
- **Aplicativo Desktop** - Obtenha assistência de IA em qualquer aplicativo com Quick Assist, General Assist, Screenshot Assist e sincronização de pastas locais. ...e mais por vir. @@ -76,48 +77,118 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7 4. Quando tudo estiver indexado, pergunte o que quiser (Casos de uso): - - Aplicativo Desktop — General Assist + **Aplicativo Desktop** (extras nativos, além de tudo o que está abaixo, não um conjunto separado) + + - General Assist: abra o SurfSense instantaneamente de qualquer aplicativo com um atalho global.

General Assist

- - Aplicativo Desktop — Quick Assist + - Quick Assist: selecione um texto em qualquer lugar e peça à IA para explicar, reescrever ou agir sobre ele.

Quick Assist

- - Aplicativo Desktop — Screenshot Assist + - Screenshot Assist: capture qualquer região da tela e pergunte à IA sobre o que está nela.

Screenshot Assist

- - Aplicativo Desktop — Watch Local Folder + - Watch Local Folder: sincronize automaticamente uma pasta local com sua base de conhecimento. Ótimo para cofres do Obsidian.

Watch Local Folder

- - Geração de vídeos + **Estúdio de Entregáveis** -

Geração de Vídeos

+ - AI Report Generator: gere relatórios de pesquisa com citações e exporte para PDF, DOCX, HTML, LaTeX, EPUB, ODT ou texto simples. - - Busca básica e citações +

AI Report Generator

-

Busca e Citação

+ - AI Podcast Generator: transforme qualquer documento ou pasta em um podcast de IA com dois apresentadores em menos de 20 segundos. - - QNA com menção de documentos +

AI Podcast Generator

-

QNA com Menção de Documentos

-

QNA com Menção de Documentos

+ - AI Presentation & Video Maker: crie apresentações editáveis e vídeos narrados a partir das suas fontes. - - Geração de relatórios e exportações (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto simples) +

AI Presentation and Video Maker

-

Geração de Relatórios

+ - AI Image Generator: gere imagens de alta qualidade diretamente das suas conversas e documentos. - - Geração de podcasts +

AI Image Generator

-

Geração de Podcasts

+ - AI Resume Builder: adapte seu currículo atual a qualquer descrição de vaga e supere o ATS. + Experimente prompts como estes: - - Geração de imagens + - "Adapte meu currículo a esta descrição de vaga para passar pelo ATS e conseguir uma entrevista." + - "Otimize meu currículo para o ATS combinando as palavras-chave desta vaga." + - "Reescreva os tópicos do meu currículo para destacar as habilidades que esta vaga exige." + - "Compare meu currículo com esta descrição de vaga e liste as lacunas a corrigir." + - "Escreva uma carta de apresentação combinando com meu currículo e esta descrição de vaga." -

Geração de Imagens

+ **Busca e Chat** - - E mais em breve. + - Chat With Your PDFs & Docs: faça perguntas sobre todos os seus arquivos e receba respostas com citações inline. + +

Chat With Your PDFs and Docs

+ + - AI Search With Citations: busca híbrida semântica e por palavra-chave em toda a sua base de conhecimento. + +

AI Search With Citations

+ + - Collaborative AI Chat: trabalhe em conversas de IA com sua equipe em tempo real. + +

Collaborative AI Chat

+ + - Comments & Mentions: comente e marque colegas em qualquer mensagem de IA. + +

Comments and Mentions

+ + **Conectores e Integrações** + + - Connect & Sync Your Tools: sincronize Notion, Slack, Google Drive, Gmail, GitHub, Linear e mais de 25 fontes em um único acervo pesquisável. + +

Connect and Sync Your Tools

+ + - Chat With Uploaded Files: envie PDFs, documentos do Office, imagens e áudio. Pesquisáveis instantaneamente. + +

Chat With Uploaded Files

+ + - Connector Write-Back: deixe o agente publicar os resultados de volta no Notion, Slack, Linear e Drive. + Experimente prompts como estes: + + - "Publique este resumo de pesquisa no meu espaço do Notion." + - "Envie estes itens de ação da reunião para o nosso canal do Slack." + - "Crie um ticket no Jira a partir deste relatório de bug." + - "Abra uma issue no Linear a partir desta solicitação de funcionalidade." + - "Salve este relatório gerado no Google Drive como um documento." + + - Obsidian & Knowledge Base Sync: mantenha seu cofre do Obsidian e sua base de conhecimento pessoal sincronizados. + + **Automações** + + - Scheduled AI Workflows: execute um agente em uma programação: resumos diários, boletins semanais, relatórios recorrentes. + Experimente prompts como estes: + + - "Envie-me todas as manhãs um resumo diário dos novos documentos na minha base de conhecimento." + - "Gere um relatório de status semanal a partir do meu Slack e Gmail toda sexta-feira." + - "Execute um relatório mensal de análise da concorrência e salve-o no meu espaço de trabalho." + - "Resuma minha atividade no GitHub e Linear em uma atualização diária de standup." + - "Crie um relatório de pesquisa semanal recorrente sobre os temas que acompanho." + + - Event-Triggered Automations: dispare um agente no momento em que um documento chega a uma pasta e publique o resultado nas suas ferramentas. + Experimente prompts como estes: + + - "Quando um PDF chegar à minha pasta de Pesquisa, gere um resumo com IA e citações." + - "Quando novas notas de reunião forem adicionadas, transforme-as em atas com itens de ação." + - "Quando uma fatura for enviada, extraia o fornecedor, o total e a data de vencimento em uma tabela." + - "Quando um contrato entrar na minha pasta Jurídica, sinalize os termos-chave e as datas de renovação." + - "Quando um currículo for adicionado a Candidatos, avalie-o em relação à descrição da vaga." 
+ + - Chat-Built Automations: descreva uma automação em linguagem simples e o SurfSense a cria para você. + Experimente prompts como estes: + + - "Crie um agente de IA que me envie todas as manhãs um resumo das novas páginas do Notion." + - "Crie uma automação sem código que publique um resumo de pesquisa semanal no Slack." + - "Configure um anotador com IA que transforme as novas notas de reunião em atas." + - "Crie um fluxo que extraia os itens de ação das notas de reunião e atribua responsáveis." + - "Automatize um resumo diário por e-mail a partir do meu Gmail e Google Drive." ### Auto-Hospedado @@ -199,6 +270,7 @@ Todos os recursos operam no espaço de busca escolhido, para que suas respostas | **Geração de Vídeos** | Visões gerais cinemáticas via Veo 3 (apenas Ultra) | Disponível (NotebookLM é melhor aqui, melhorando ativamente) | | **Geração de Apresentações** | Slides mais bonitos mas não editáveis | Cria apresentações editáveis baseadas em slides | | **Geração de Podcasts** | Visões gerais em áudio com hosts e idiomas personalizáveis | Disponível com múltiplos provedores TTS (NotebookLM é melhor aqui, melhorando ativamente) | +| **Automações e Agentes de IA** | Não | Fluxos de trabalho de IA agendados, gatilhos por eventos em novos documentos e automações sem código criadas por chat com escrita de volta no Notion, Slack, Linear e Jira | | **Aplicativo Desktop** | Não | Aplicativo nativo com General Assist, Quick Assist, Screenshot Assist e sincronização de pastas locais | | **Extensão de Navegador** | Não | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação | diff --git a/README.zh-CN.md b/README.zh-CN.md index a07f4afdc..d3d5330f6 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -41,6 +41,7 @@ NotebookLM 是目前最好、最实用的 AI 平台之一,但当你开始经 - **无供应商锁定** - 配置任何 LLM、图像、TTS 和 STT 模型。 - **25+ 外部数据源** - 从 Google Drive、OneDrive、Dropbox、Notion 和许多其他外部服务添加你的来源。 - **实时多人协作支持** - 在共享笔记本中轻松与团队成员协作。 +- **AI 自动化与智能体** - 
按计划运行 AI 智能体,或在文档进入文件夹的那一刻触发它们,然后将结果回写到 Notion、Slack、Linear 和 Drive。只需在聊天中描述即可创建无代码自动化。 - **桌面应用** - 通过 Quick Assist、General Assist、Screenshot Assist 和本地文件夹同步在任何应用程序中获得 AI 助手。 ...更多功能即将推出。 @@ -76,48 +77,118 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7 4. 一切索引完成后,尽管提问(使用场景): - - 桌面应用 — General Assist + **桌面应用**(在以下所有功能之外的原生附加功能,并非独立的功能集) + + - General Assist:通过全局快捷键,从任意应用中即刻打开 SurfSense。

General Assist

- - 桌面应用 — Quick Assist + - Quick Assist:在任意位置选中文本,让 AI 解释、改写或对其执行操作。

Quick Assist

- - 桌面应用 — Screenshot Assist + - Screenshot Assist:截取屏幕上任意区域,并就其中内容向 AI 提问。

Screenshot Assist

- - 桌面应用 — Watch Local Folder + - Watch Local Folder:将本地文件夹自动同步到你的知识库。非常适合 Obsidian 库。

Watch Local Folder

- - 视频生成 + **成果工作室** -

视频生成

+ - AI Report Generator:生成带引用的研究报告,并导出为 PDF、DOCX、HTML、LaTeX、EPUB、ODT 或纯文本。 - - 基本搜索和引用 +

AI Report Generator

-

搜索和引用

+ - AI Podcast Generator:在 20 秒内将任意文档或文件夹转换为双主持人 AI 播客。 - - 文档提及问答 +

AI Podcast Generator

-

文档提及问答

-

文档提及问答

+ - AI Presentation & Video Maker:根据你的资料创建可编辑的幻灯片和带旁白的视频概览。 - - 报告生成和导出(PDF、DOCX、HTML、LaTeX、EPUB、ODT、纯文本) +

AI Presentation and Video Maker

-

报告生成

+ - AI Image Generator:直接从你的聊天和文档生成高质量图像。 - - 播客生成 +

AI Image Generator

-

播客生成

+ - AI Resume Builder:根据任意职位描述定制你现有的简历,顺利通过 ATS。 + 可以试试这样的提示: - - 图像生成 + - “根据这份职位描述定制我的简历,让它通过 ATS 并赢得面试。” + - “匹配这份招聘启事中的关键词,为 ATS 优化我的简历。” + - “重写我的简历要点,突出这个岗位所需要的技能。” + - “将我的简历与这份职位描述对比,列出需要改进的差距。” + - “根据我的简历和这份职位描述,写一封相匹配的求职信。” -

图像生成

+ **搜索与聊天** - - 更多功能即将推出。 + - Chat With Your PDFs & Docs:跨所有文件提问,并获得带内联引用的答案。 + +

Chat With Your PDFs and Docs

+ + - AI Search With Citations:在整个知识库中进行语义与关键词的混合搜索。 + +

AI Search With Citations

+ + - Collaborative AI Chat:与团队实时协作处理 AI 对话。 + +

Collaborative AI Chat

+ + - Comments & Mentions:在任意 AI 消息上评论并 @ 你的队友。 + +

Comments and Mentions

+ + **连接器与集成** + + - Connect & Sync Your Tools:将 Notion、Slack、Google Drive、Gmail、GitHub、Linear 等 25+ 数据源同步为一个可搜索的语料库。 + +

Connect and Sync Your Tools

+ + - Chat With Uploaded Files:上传 PDF、Office 文档、图像和音频。即刻可搜索。 + +

Chat With Uploaded Files

+ + - Connector Write-Back:让智能体将结果回写到 Notion、Slack、Linear 和 Drive。 + 可以试试这样的提示: + + - “把这份研究摘要发布到我的 Notion 工作区。” + - “把这些会议行动项发送到我们的团队 Slack 频道。” + - “根据这份缺陷报告创建一个 Jira 工单。” + - “根据这个功能需求在 Linear 中创建一个 issue。” + - “把这份生成的报告作为文档保存到 Google Drive。” + + - Obsidian & Knowledge Base Sync:让你的 Obsidian 库与个人知识库保持同步。 + + **自动化** + + - Scheduled AI Workflows:按计划运行智能体:每日简报、每周摘要、周期性报告。 + 可以试试这样的提示: + + - “每天早上把我知识库中新增文档的每日简报发邮件给我。” + - “每周五根据我的 Slack 和 Gmail 生成一份每周状态报告。” + - “每月运行一次竞争对手分析报告并保存到我的工作区。” + - “把我的 GitHub 和 Linear 活动汇总成一份每日站会更新。” + - “针对我关注的主题创建一份周期性的每周研究报告。” + + - Event-Triggered Automations:在文档进入文件夹的那一刻触发智能体,并将结果发布到你的工具中。 + 可以试试这样的提示: + + - “当一个 PDF 进入我的 Research 文件夹时,生成一份带引用的 AI 摘要。” + - “当新增会议记录时,把它整理成带行动项的会议纪要。” + - “当上传发票时,把供应商、总额和到期日提取到一张表格中。” + - “当一份合同进入我的 Legal 文件夹时,标记关键条款和续约日期。” + - “当一份简历加入 Candidates 时,根据职位描述对其进行筛选。” + + - Chat-Built Automations:用通俗的语言描述一个自动化,SurfSense 就会为你构建它。 + 可以试试这样的提示: + + - “创建一个 AI 智能体,每天早上把新增 Notion 页面的摘要发邮件给我。” + - “创建一个无代码自动化,每周把研究摘要发布到 Slack。” + - “设置一个 AI 笔记助手,把新增会议记录整理成纪要。” + - “创建一个工作流,从会议记录中提取行动项并指派负责人。” + - “自动化一份来自我的 Gmail 和 Google Drive 的每日邮件简报。” ### 自托管 @@ -199,6 +270,7 @@ SurfSense 还提供桌面应用,将 AI 助手带到您计算机上的每个应 | **视频生成** | 通过 Veo 3 的电影级视频概览(仅 Ultra) | 可用(NotebookLM 在此方面更好,正在积极改进) | | **演示文稿生成** | 更美观的幻灯片但不可编辑 | 创建可编辑的幻灯片式演示文稿 | | **播客生成** | 可自定义主持人和语言的音频概览 | 可用,支持多种 TTS 提供商(NotebookLM 在此方面更好,正在积极改进) | +| **AI 自动化与智能体** | 否 | 定时 AI 工作流、新文档的事件触发,以及通过聊天构建的无代码自动化,支持回写到 Notion、Slack、Linear 和 Jira | | **桌面应用** | 否 | 原生应用,包含 General Assist、Quick Assist、Screenshot Assist 和本地文件夹同步 | | **浏览器扩展** | 否 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 | diff --git a/VERSION b/VERSION index 2678ff8d6..c4475d3bb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.25 +0.0.26 diff --git a/docker/.env.example b/docker/.env.example index 4de35a5e9..748f03048 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -7,6 +7,9 @@ # SurfSense version (use "latest" or a specific version like "0.0.14") SURFSENSE_VERSION=latest +# Deployment environment: dev or 
production +SURFSENSE_ENV=production + # ------------------------------------------------------------------------------ # Core Settings # ------------------------------------------------------------------------------ @@ -304,6 +307,28 @@ STT_SERVICE=local/base # LANGSMITH_API_KEY= # LANGSMITH_PROJECT=surfsense +# OpenTelemetry traces and metrics. +# Enable the collector with: docker compose --profile observability up -d +# SURFSENSE_ENABLE_OTEL=true +# OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 +# OTEL_EXPORTER_OTLP_PROTOCOL=grpc +# OTEL_RESOURCE_ATTRIBUTES=service.namespace=surfsense +# +# Emergency kill switch. +# OTEL_SDK_DISABLED=true +# +# Grafana Cloud OTLP credentials. These are used only by the collector container. +# GRAFANA_CLOUD_OTLP_ENDPOINT=https://otlp-gateway-<zone>.grafana.net/otlp +# GRAFANA_CLOUD_INSTANCE_ID= +# GRAFANA_CLOUD_API_KEY= +# +# Optional host port overrides for the bundled OTel Collector. Only change +# these if the host already uses 4317/4318/13133; backend containers still use +# the internal Docker endpoint above. 
+# OTEL_GRPC_PORT=4317 +# OTEL_HTTP_PORT=4318 +# OTEL_HEALTH_PORT=13133 + # ------------------------------------------------------------------------------ # Advanced (optional) # ------------------------------------------------------------------------------ diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 53b8ea1a9..58cb7b42f 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -78,6 +78,15 @@ services: timeout: 5s retries: 5 + otel-lgtm: + image: grafana/otel-lgtm:latest + ports: + - "${OTEL_GRPC_PORT:-4317}:4317" + - "${OTEL_HTTP_PORT:-4318}:4318" + - "${OTEL_GRAFANA_PORT:-3001}:3000" + - "${OTEL_TEMPO_PORT:-3200}:3200" + restart: unless-stopped + searxng: image: searxng/searxng:2026.3.13-3c1f68c59 ports: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 82d77f826..06a3ac79a 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -61,6 +61,29 @@ services: timeout: 5s retries: 5 + otel-collector: + image: otel/opentelemetry-collector-contrib:0.152.1 + profiles: + - observability + command: ["--config=/etc/otelcol/config.yaml"] + volumes: + - ./otel-collector/config.yaml:/etc/otelcol/config.yaml:ro + environment: + GRAFANA_CLOUD_OTLP_ENDPOINT: ${GRAFANA_CLOUD_OTLP_ENDPOINT:-} + GRAFANA_CLOUD_INSTANCE_ID: ${GRAFANA_CLOUD_INSTANCE_ID:-} + GRAFANA_CLOUD_API_KEY: ${GRAFANA_CLOUD_API_KEY:-} + ports: + - "${OTEL_GRPC_PORT:-4317}:4317" + - "${OTEL_HTTP_PORT:-4318}:4318" + - "${OTEL_HEALTH_PORT:-13133}:13133" + mem_limit: 2g + restart: unless-stopped + healthcheck: + test: ["CMD", "/otelcol-contrib", "--version"] + interval: 30s + timeout: 5s + retries: 3 + searxng: image: searxng/searxng:2026.3.13-3c1f68c59 volumes: diff --git a/docker/otel-collector/config.yaml b/docker/otel-collector/config.yaml new file mode 100644 index 000000000..f495eff9b --- /dev/null +++ b/docker/otel-collector/config.yaml @@ -0,0 +1,81 @@ +extensions: + health_check: + endpoint: 0.0.0.0:13133 + 
basicauth/grafana_cloud: + client_auth: + username: ${env:GRAFANA_CLOUD_INSTANCE_ID} + password: ${env:GRAFANA_CLOUD_API_KEY} + +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 + +processors: + # Percentage limits are calculated against the collector container memory limit. + # Keep docker-compose.yml/Coolify memory limit set for predictability. + memory_limiter: + check_interval: 1s + limit_percentage: 80 + spike_limit_percentage: 25 + + attributes/scrub: + actions: + - key: http.request.header.authorization + action: delete + - key: http.request.header.cookie + action: delete + - key: db.statement + action: delete + + tail_sampling: + decision_wait: 10s + num_traces: 50000 + expected_new_traces_per_sec: 100 + policies: + - name: errors + type: status_code + status_code: + status_codes: [ERROR] + - name: slow-requests + type: latency + latency: + threshold_ms: 500 + - name: baseline + type: probabilistic + probabilistic: + sampling_percentage: 100 + + batch: + timeout: 5s + send_batch_size: 1024 + send_batch_max_size: 2048 + +exporters: + otlp_http/grafana_cloud: + endpoint: ${env:GRAFANA_CLOUD_OTLP_ENDPOINT} + auth: + authenticator: basicauth/grafana_cloud + sending_queue: + enabled: true + queue_size: 10000 + retry_on_failure: + enabled: true + initial_interval: 5s + max_interval: 30s + max_elapsed_time: 300s + +service: + extensions: [health_check, basicauth/grafana_cloud] + pipelines: + traces: + receivers: [otlp] + processors: [memory_limiter, attributes/scrub, tail_sampling, batch] + exporters: [otlp_http/grafana_cloud] + metrics: + receivers: [otlp] + processors: [memory_limiter, batch] + exporters: [otlp_http/grafana_cloud] diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 3d442973c..70cf687d8 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -1,5 +1,8 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense +# 
Deployment environment: dev or production +SURFSENSE_ENV=dev + #Celery Config CELERY_BROKER_URL=redis://localhost:6379/0 CELERY_RESULT_BACKEND=redis://localhost:6379/0 @@ -303,8 +306,16 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_BUSY_MUTEX=false # SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call -# Observability — OTel (also requires OTEL_EXPORTER_OTLP_ENDPOINT) +# Observability - OTel # SURFSENSE_ENABLE_OTEL=false +# OpenTelemetry - endpoint enables export; absent = no-op. +# Production should point at an OTel Collector. For local docker-compose.dev.yml, +# use http://otel-lgtm:4317 instead. +# OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 +# OTEL_EXPORTER_OTLP_PROTOCOL=grpc # or http/protobuf +# OTEL_RESOURCE_ATTRIBUTES=service.namespace=surfsense +# OTEL_METRIC_EXPORT_INTERVAL=300000 # ms; 5 minutes +# OTEL_SDK_DISABLED=true # emergency kill-switch # Skills + subagents # SURFSENSE_ENABLE_SKILLS=false @@ -346,3 +357,50 @@ LANGSMITH_PROJECT=surfsense # updates and deletes — the TTL only bounds staleness for bulk-import # paths that bypass the ORM. Set to 0 to disable the cache. # SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS=30 + +# ----------------------------------------------------------------------------- +# `task` boundary controls (Hermes-inspired improvements) +# ----------------------------------------------------------------------------- +# Wall-clock budget for a single ``task(subagent, ...)`` invocation in +# seconds. Subagents that run hot (slow image vendors, sluggish embedders, +# wedged MCP servers) would otherwise pin the orchestrator until the next +# checkpoint heartbeat fires. On timeout the runtime cancels the underlying +# coroutine and synthesizes a ToolMessage telling the orchestrator to treat +# the result as ``status=error``. Set to 0 to disable the cap entirely. +# Default: 300.0 +# SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS=300 + +# Batch-mode (``task(tasks=[...])``) concurrency cap and max batch size. 
+# Concurrency is enforced via an ``asyncio.Semaphore`` so a runaway fanout +# cannot starve unrelated subagents (each child still owns an LLM call and +# its own DB session). Max-size is a hard safety net for prompt-injection / +# runaway loops; the orchestrator rarely needs more than a handful of +# concurrent specialists. Set concurrency to 1 to effectively serialise +# batches without changing the schema. +# SURFSENSE_TASK_BATCH_CONCURRENCY=3 +# SURFSENSE_TASK_BATCH_MAX_SIZE=8 + +# Soft per-turn cap on cumulative ``task(...)`` invocations across all +# subagents. Once the sum of ``state['billable_calls']`` crosses this +# number, the runtime appends a one-shot warning ToolMessage telling the +# orchestrator to wrap up rather than launching more specialists. Tunable +# so heavy-research turns (15+ legitimate specialist calls) don't trip the +# alarm in production. Set to 0 to disable the warning entirely. +# SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD=15 + +# Per-workspace spawn-paused kill switch — set via Redis at runtime, not +# this env var. The env var below only disables the check itself (useful +# for local dev without Redis). To pause a workspace in production: +# redis-cli SET surfsense:spawn_paused:<workspace_id> 1 EX 600 +# redis-cli DEL surfsense:spawn_paused:<workspace_id> +# The check is fail-open: a Redis blip never blocks ``task(...)``. +# SURFSENSE_TASK_SPAWN_PAUSED_DISABLED=false + +# Note on Celery-backed deliverables (generate_podcast, +# generate_video_presentation): these tools poll the artefact row until +# it reaches a terminal status — they do NOT use an internal wall-clock +# budget. The effective ceiling is SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS +# (above, default 300s) in multi-agent mode and the chat's HTTP / process +# lifetime in single-agent mode. If your podcasts or videos routinely +# exceed 5 minutes, raise SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS (or +# set it to 0 to disable that ceiling entirely). 
diff --git a/surfsense_backend/alembic/versions/144_add_automation_tables.py b/surfsense_backend/alembic/versions/144_add_automation_tables.py new file mode 100644 index 000000000..39f927417 --- /dev/null +++ b/surfsense_backend/alembic/versions/144_add_automation_tables.py @@ -0,0 +1,177 @@ +"""Add automation tables (automations, automation_triggers, automation_runs) + +Revision ID: 144 +Revises: 143 +Create Date: 2026-05-26 + +Adds the three tables that back the v1 automation engine, plus the +three PostgreSQL ENUM types they reference. Matches the SQLAlchemy +models under ``app.automations.persistence.models`` and the v1 data +model in ``automation-design-plan.md`` §9. + +v1 ships these three tables only. ``domain_events`` is deferred to +Phase 3 with the event trigger; ``mcp_connections`` / ``mcp_tools`` +are deferred to Phase 4 with the MCP integration. +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "144" +down_revision: str | None = "143" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ENUM types (PostgreSQL requires types created before tables that use them) + op.execute( + """ + CREATE TYPE automation_status AS ENUM ( + 'active', 'paused', 'archived' + ); + """ + ) + op.execute( + """ + CREATE TYPE automation_trigger_type AS ENUM ( + 'schedule', 'manual' + ); + """ + ) + op.execute( + """ + CREATE TYPE automation_run_status AS ENUM ( + 'pending', 'running', 'succeeded', 'failed', + 'cancelled', 'timed_out' + ); + """ + ) + + # automations — the editable, versioned automation definition + op.execute( + """ + CREATE TABLE automations ( + id SERIAL PRIMARY KEY, + search_space_id INTEGER NOT NULL + REFERENCES searchspaces(id) ON DELETE CASCADE, + created_by_user_id UUID + REFERENCES "user"(id) ON DELETE SET NULL, + name VARCHAR(200) NOT NULL, + description TEXT, + status automation_status NOT NULL DEFAULT 'active', + definition JSONB NOT 
NULL, + version INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + """ + ) + op.execute( + "CREATE INDEX ix_automations_search_space_id ON automations(search_space_id);" + ) + op.execute( + "CREATE INDEX ix_automations_created_by_user_id ON automations(created_by_user_id);" + ) + op.execute("CREATE INDEX ix_automations_status ON automations(status);") + op.execute("CREATE INDEX ix_automations_created_at ON automations(created_at);") + op.execute("CREATE INDEX ix_automations_updated_at ON automations(updated_at);") + + # automation_triggers — one row per (automation, trigger-instance) pair + op.execute( + """ + CREATE TABLE automation_triggers ( + id SERIAL PRIMARY KEY, + automation_id INTEGER NOT NULL + REFERENCES automations(id) ON DELETE CASCADE, + type automation_trigger_type NOT NULL, + params JSONB NOT NULL, + static_inputs JSONB NOT NULL DEFAULT '{}'::jsonb, + enabled BOOLEAN NOT NULL DEFAULT true, + last_fired_at TIMESTAMP WITH TIME ZONE, + next_fire_at TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + """ + ) + op.execute( + "CREATE INDEX ix_automation_triggers_automation_id ON automation_triggers(automation_id);" + ) + op.execute("CREATE INDEX ix_automation_triggers_type ON automation_triggers(type);") + op.execute( + "CREATE INDEX ix_automation_triggers_enabled ON automation_triggers(enabled);" + ) + op.execute( + "CREATE INDEX ix_automation_triggers_created_at ON automation_triggers(created_at);" + ) + # Partial index for the schedule tick: only enabled schedule triggers + # with a scheduled next fire are ever scanned for due rows. 
+ op.execute( + """ + CREATE INDEX ix_automation_triggers_due + ON automation_triggers (next_fire_at) + WHERE enabled = true + AND type = 'schedule' + AND next_fire_at IS NOT NULL; + """ + ) + + # automation_runs — the immutable per-fire execution record + op.execute( + """ + CREATE TABLE automation_runs ( + id SERIAL PRIMARY KEY, + automation_id INTEGER NOT NULL + REFERENCES automations(id) ON DELETE CASCADE, + trigger_id INTEGER + REFERENCES automation_triggers(id) ON DELETE SET NULL, + status automation_run_status NOT NULL DEFAULT 'pending', + definition_snapshot JSONB NOT NULL, + inputs JSONB NOT NULL DEFAULT '{}'::jsonb, + step_results JSONB NOT NULL DEFAULT '[]'::jsonb, + output JSONB, + artifacts JSONB NOT NULL DEFAULT '[]'::jsonb, + error JSONB, + started_at TIMESTAMP WITH TIME ZONE, + finished_at TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + """ + ) + op.execute( + "CREATE INDEX ix_automation_runs_automation_id ON automation_runs(automation_id);" + ) + op.execute( + "CREATE INDEX ix_automation_runs_trigger_id ON automation_runs(trigger_id);" + ) + op.execute("CREATE INDEX ix_automation_runs_status ON automation_runs(status);") + op.execute( + "CREATE INDEX ix_automation_runs_created_at ON automation_runs(created_at);" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_automation_runs_created_at;") + op.execute("DROP INDEX IF EXISTS ix_automation_runs_status;") + op.execute("DROP INDEX IF EXISTS ix_automation_runs_trigger_id;") + op.execute("DROP INDEX IF EXISTS ix_automation_runs_automation_id;") + op.execute("DROP TABLE IF EXISTS automation_runs;") + + op.execute("DROP INDEX IF EXISTS ix_automation_triggers_due;") + op.execute("DROP INDEX IF EXISTS ix_automation_triggers_created_at;") + op.execute("DROP INDEX IF EXISTS ix_automation_triggers_enabled;") + op.execute("DROP INDEX IF EXISTS ix_automation_triggers_type;") + op.execute("DROP INDEX IF EXISTS 
ix_automation_triggers_automation_id;") + op.execute("DROP TABLE IF EXISTS automation_triggers;") + + op.execute("DROP INDEX IF EXISTS ix_automations_updated_at;") + op.execute("DROP INDEX IF EXISTS ix_automations_created_at;") + op.execute("DROP INDEX IF EXISTS ix_automations_status;") + op.execute("DROP INDEX IF EXISTS ix_automations_created_by_user_id;") + op.execute("DROP INDEX IF EXISTS ix_automations_search_space_id;") + op.execute("DROP TABLE IF EXISTS automations;") + + op.execute("DROP TYPE IF EXISTS automation_run_status;") + op.execute("DROP TYPE IF EXISTS automation_trigger_type;") + op.execute("DROP TYPE IF EXISTS automation_status;") diff --git a/surfsense_backend/alembic/versions/145_add_automations_permissions_to_roles.py b/surfsense_backend/alembic/versions/145_add_automations_permissions_to_roles.py new file mode 100644 index 000000000..779656b44 --- /dev/null +++ b/surfsense_backend/alembic/versions/145_add_automations_permissions_to_roles.py @@ -0,0 +1,87 @@ +"""Add automations permissions to existing Editor/Viewer roles + +Revision ID: 145 +Revises: 144 +Create Date: 2026-05-27 + +Owners already have ``*`` and need no backfill. Custom (non-system) roles +are left untouched on purpose: workspace admins manage those explicitly. 
+""" + +from collections.abc import Sequence + +from sqlalchemy import text + +from alembic import op + +revision: str = "145" +down_revision: str | None = "144" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +_EDITOR_PERMISSIONS = ( + "automations:create", + "automations:read", + "automations:update", + "automations:execute", +) +_VIEWER_PERMISSIONS = ("automations:read",) + + +def upgrade(): + connection = op.get_bind() + + for permission in _EDITOR_PERMISSIONS: + connection.execute( + text( + """ + UPDATE search_space_roles + SET permissions = array_append(permissions, :permission) + WHERE name = 'Editor' + AND NOT (:permission = ANY(permissions)) + """ + ), + {"permission": permission}, + ) + + for permission in _VIEWER_PERMISSIONS: + connection.execute( + text( + """ + UPDATE search_space_roles + SET permissions = array_append(permissions, :permission) + WHERE name = 'Viewer' + AND NOT (:permission = ANY(permissions)) + """ + ), + {"permission": permission}, + ) + + +def downgrade(): + connection = op.get_bind() + + for permission in _EDITOR_PERMISSIONS: + connection.execute( + text( + """ + UPDATE search_space_roles + SET permissions = array_remove(permissions, :permission) + WHERE name = 'Editor' + """ + ), + {"permission": permission}, + ) + + for permission in _VIEWER_PERMISSIONS: + connection.execute( + text( + """ + UPDATE search_space_roles + SET permissions = array_remove(permissions, :permission) + WHERE name = 'Viewer' + """ + ), + {"permission": permission}, + ) diff --git a/surfsense_backend/alembic/versions/146_drop_surfsense_docs_tables.py b/surfsense_backend/alembic/versions/146_drop_surfsense_docs_tables.py new file mode 100644 index 000000000..725405834 --- /dev/null +++ b/surfsense_backend/alembic/versions/146_drop_surfsense_docs_tables.py @@ -0,0 +1,129 @@ +"""Drop Surfsense docs tables (feature removed end to end) + +Revision ID: 146 +Revises: 145 +Create Date: 2026-05-28 + +Removes the 
SurfSense product-documentation feature: the +``surfsense_docs_documents`` and ``surfsense_docs_chunks`` tables (created +in revision 60) and the GIN trigram index on the title column (added in +revision 67). The docs were seeded at startup from local MDX files, so no +user data is lost. Downgrade recreates the tables and indexes. +""" + +from collections.abc import Sequence + +from alembic import op +from app.config import config + +# revision identifiers, used by Alembic. +revision: str = "146" +down_revision: str | None = "145" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +# Embedding dimension is required to recreate the vector columns on downgrade. +EMBEDDING_DIM = config.embedding_model_instance.dimension + + +def upgrade() -> None: + """Drop surfsense docs tables and all their indexes.""" + # Trigram index from revision 67 + op.execute("DROP INDEX IF EXISTS idx_surfsense_docs_title_trgm") + + # Full-text search indexes + op.execute("DROP INDEX IF EXISTS surfsense_docs_chunks_search_index") + op.execute("DROP INDEX IF EXISTS surfsense_docs_documents_search_index") + + # Vector indexes + op.execute("DROP INDEX IF EXISTS surfsense_docs_chunks_vector_index") + op.execute("DROP INDEX IF EXISTS surfsense_docs_documents_vector_index") + + # B-tree indexes + op.execute("DROP INDEX IF EXISTS ix_surfsense_docs_chunks_document_id") + op.execute("DROP INDEX IF EXISTS ix_surfsense_docs_documents_updated_at") + op.execute("DROP INDEX IF EXISTS ix_surfsense_docs_documents_content_hash") + op.execute("DROP INDEX IF EXISTS ix_surfsense_docs_documents_source") + + # Tables (chunks first due to FK) + op.execute("DROP TABLE IF EXISTS surfsense_docs_chunks") + op.execute("DROP TABLE IF EXISTS surfsense_docs_documents") + + +def downgrade() -> None: + """Recreate surfsense docs tables and indexes (reverses revisions 60 + 67).""" + op.execute( + f""" + CREATE TABLE IF NOT EXISTS surfsense_docs_documents ( + id SERIAL PRIMARY 
KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + source VARCHAR NOT NULL UNIQUE, + title VARCHAR NOT NULL, + content TEXT NOT NULL, + content_hash VARCHAR NOT NULL, + embedding vector({EMBEDDING_DIM}), + updated_at TIMESTAMP WITH TIME ZONE + ); + """ + ) + op.execute( + f""" + CREATE TABLE IF NOT EXISTS surfsense_docs_chunks ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + content TEXT NOT NULL, + embedding vector({EMBEDDING_DIM}), + document_id INTEGER NOT NULL REFERENCES surfsense_docs_documents(id) ON DELETE CASCADE + ); + """ + ) + + # B-tree indexes + op.execute( + "CREATE INDEX IF NOT EXISTS ix_surfsense_docs_documents_source ON surfsense_docs_documents(source)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_surfsense_docs_documents_content_hash ON surfsense_docs_documents(content_hash)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_surfsense_docs_documents_updated_at ON surfsense_docs_documents(updated_at)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_surfsense_docs_chunks_document_id ON surfsense_docs_chunks(document_id)" + ) + + # Vector indexes + op.execute( + """ + CREATE INDEX IF NOT EXISTS surfsense_docs_documents_vector_index + ON surfsense_docs_documents USING hnsw (embedding public.vector_cosine_ops); + """ + ) + op.execute( + """ + CREATE INDEX IF NOT EXISTS surfsense_docs_chunks_vector_index + ON surfsense_docs_chunks USING hnsw (embedding public.vector_cosine_ops); + """ + ) + + # Full-text search indexes + op.execute( + """ + CREATE INDEX IF NOT EXISTS surfsense_docs_documents_search_index + ON surfsense_docs_documents USING gin (to_tsvector('english', content)); + """ + ) + op.execute( + """ + CREATE INDEX IF NOT EXISTS surfsense_docs_chunks_search_index + ON surfsense_docs_chunks USING gin (to_tsvector('english', content)); + """ + ) + + # Trigram index from revision 67 + op.execute( + """ + CREATE INDEX IF NOT EXISTS idx_surfsense_docs_title_trgm + ON 
surfsense_docs_documents USING gin (title gin_trgm_ops); + """ + ) diff --git a/surfsense_backend/alembic/versions/147_add_event_to_automation_trigger_type.py b/surfsense_backend/alembic/versions/147_add_event_to_automation_trigger_type.py new file mode 100644 index 000000000..32021a9d1 --- /dev/null +++ b/surfsense_backend/alembic/versions/147_add_event_to_automation_trigger_type.py @@ -0,0 +1,47 @@ +"""Add 'event' to automation_trigger_type enum + +Revision ID: 147 +Revises: 146 +Create Date: 2026-05-29 + +Adds the ``event`` value to the ``automation_trigger_type`` enum so automations +can be triggered by published domain events, alongside the existing +``schedule`` and ``manual`` triggers. +""" + +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "147" +down_revision: str | None = "146" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +ENUM_NAME = "automation_trigger_type" +NEW_VALUE = "event" + + +def upgrade() -> None: + """Safely add 'event' to automation_trigger_type enum if missing.""" + op.execute( + f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = '{ENUM_NAME}' AND e.enumlabel = '{NEW_VALUE}' + ) THEN + ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}'; + END IF; + END + $$; + """ + ) + + +def downgrade() -> None: + """No-op: PostgreSQL does not support removing enum values.""" + pass diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py index 03cf7acb8..df1ee1b4c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,6 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: 
list[str] | None, config_id: str | None, + image_generation_config_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -91,7 +92,7 @@ async def build_agent_with_cache( # the key, otherwise a hit will leak state across threads. Bump the schema # version when the component list changes shape. cache_key = stable_hash( - "multi-agent-v1", + "multi-agent-v2", config_id, thread_id, user_id, @@ -109,6 +110,10 @@ async def build_agent_with_cache( system_prompt_hash(final_system_prompt), max_input_tokens, sorted(disabled_tools) if disabled_tools else None, + # Bound into the generate_image subagent tool at construction time, so it + # must key the compiled-agent cache to avoid leaking one automation's + # image model into another with the same config_id/search_space. + image_generation_config_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index 8451b3b7d..44529d243 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -62,8 +62,14 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, + image_generation_config_id: int | None = None, ): - """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.""" + """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. + + ``image_generation_config_id`` overrides the search space's image model for + this invocation (used by automations to run on their captured model). 
When + ``None``, the ``generate_image`` tool resolves the live search-space pref. + """ _t_agent_total = time.perf_counter() apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) @@ -129,6 +135,9 @@ async def create_multi_agent_chat_deep_agent( "available_document_types": available_document_types, "max_input_tokens": _max_input_tokens, "llm": llm, + # Per-invocation image model override (automations run on their captured + # model). Reaches the generate_image subagent tool via subagent_dependencies. + "image_generation_config_id_override": image_generation_config_id, } _t0 = time.perf_counter() @@ -285,6 +294,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, + image_generation_config_id_override=image_generation_config_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md index e61a0bffb..2abd95d5a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/citations/on.md @@ -4,8 +4,8 @@ never invent ids you didn't see. Citation ids are resolved by exact-match lookup; a wrong id silently breaks the link, so when in doubt, omit. ### Channel A — chunk blocks injected this turn -When `search_surfsense_docs` or `web_search` returns `` / -`` blocks in this turn: +When `web_search` returns `` / `` blocks in this +turn: 1. 
For each factual statement taken from those chunks, add `[citation:chunk_id]` using the **exact** id from a visible diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md index 71c86be40..8f2bfca4e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/private.md @@ -20,8 +20,8 @@ it to resolve paths the user describes in natural language ("my Q2 roadmap", delegating to a specialist. `` and `` blocks are chunked indexed content returned -by KB search (from `search_surfsense_docs`, or backing ``). -Each chunk carries a stable `id` attribute. +by KB search (backing ``). Each chunk carries a stable +`id` attribute. If a block doesn't appear this turn, work from the conversation alone. diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md index 592c2ed9c..a5892c23a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/dynamic_context/team.md @@ -20,8 +20,8 @@ week's planning notes") into concrete document references before delegating to a specialist. `` and `` blocks are chunked indexed content returned -by KB search (from `search_surfsense_docs`, or backing ``). -Each chunk carries a stable `id` attribute. +by KB search (backing ``). Each chunk carries a stable +`id` attribute. If a block doesn't appear this turn, work from the conversation alone. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md index f06a52c1d..80fa4bf8f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/kb_first.md @@ -1,19 +1,21 @@ CRITICAL — ground factual answers in what you actually receive this turn: - injected workspace context (see ``), -- results from your own tool calls (`search_surfsense_docs`, `web_search`, - `scrape_webpage`), +- results from your own tool calls (`web_search`, `scrape_webpage`), - or substantive summaries returned by a `task` specialist you invoked. Do **not** answer factual or informational questions from general knowledge unless the user explicitly authorises it after you say you couldn't find enough in those sources. The flow when nothing is found: -1. Say you couldn't find enough in their workspace, docs, or tool output. +1. Say you couldn't find enough in their workspace or tool output. 2. Ask: *"Would you like me to answer from my general knowledge instead?"* 3. Only answer from general knowledge after a clear yes. This rule does NOT apply to: casual conversation · meta-questions about SurfSense ("what can you do?") · formatting or analysis of content already in chat · clear rewrite/edit instructions · lightweight web research. + +For "how do I use SurfSense" / product-documentation questions, point the +user to https://www.surfsense.com/docs. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md index 89154c443..d852f5955 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/anthropic.md @@ -5,7 +5,7 @@ Structured reasoning: - For non-trivial work, `` / short `` before tool calls is fine. Professional objectivity: -- Accuracy over flattery; verify with **search_surfsense_docs**, **web_search**, **scrape_webpage**, or **task** when unsure — don’t invent connector access. +- Accuracy over flattery; verify with **web_search**, **scrape_webpage**, or **task** when unsure — don’t invent connector access. Task management: - For 3+ steps, use todo tooling; update statuses promptly. diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md index 4254e9ed5..01d56999f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/deepseek.md @@ -13,6 +13,6 @@ Attribution: Tool calls: - Parallelise independent calls. -- Prefer **search_surfsense_docs** for SurfSense docs/product questions before **web_search** when that fits the ask. +- For SurfSense docs/product questions, point the user to https://www.surfsense.com/docs. - Don’t invent paths, chunk ids, or URLs — only values from tools or the user. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md index dc5073538..32ed959c1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/google.md @@ -7,7 +7,7 @@ Output style: - GitHub-flavoured Markdown; monospace-friendly. Workflow (Understand → Plan → Act → Verify): -1. **Understand:** parse the ask; use **search_surfsense_docs** / injected workspace context before guessing. +1. **Understand:** parse the ask; use injected workspace context before guessing. 2. **Plan:** for multi-step work, a short plan first. 3. **Act:** only with tools you actually have on this agent (see `` and ``). Connector work → **task**. 4. **Verify:** re-read or re-search only when it materially reduces risk. diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md index 7ff3ec912..8596c42cd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/providers/openai_classic.md @@ -15,6 +15,7 @@ Output style: Tool calls: - Parallelise independent calls in one turn. -- Prefer **search_surfsense_docs** for SurfSense-product questions, **web_search** / **scrape_webpage** - for fresh public facts; integrations and heavy workflows → **task**. +- For SurfSense-product questions, point the user to https://www.surfsense.com/docs; + use **web_search** / **scrape_webpage** for fresh public facts; integrations and + heavy workflows → **task**. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md index 4e27381d3..28cf0ac63 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md @@ -3,10 +3,7 @@ You have two execution channels. Pick the one that owns the work — never simulate one with the other. ### 1. Direct tools (you call them yourself) -- `search_surfsense_docs` — SurfSense product docs (setup, configuration, - connector docs, feature behavior). -- `web_search` — search the public web (anything outside SurfSense docs and - the workspace KB). +- `web_search` — search the public web (anything outside the workspace KB). - `scrape_webpage` — fetch the body of a specific public URL. - `update_memory` — curate persistent memory (see ``). - `write_todos` — maintain a structured plan when the turn series spans @@ -14,6 +11,10 @@ simulate one with the other. `in_progress` **before** the `task` call that handles it, `completed` once the call returns. Skip for single-step requests. +**Questions about how to use SurfSense itself** (setup, configuration, +connectors, feature behavior) — point the user to the documentation: +https://www.surfsense.com/docs. There is no docs-search tool; give the link. + **You have NO filesystem tools.** Any read, write, edit, move, rename, or search inside the user's workspace goes through `task(knowledge_base, …)` — never via `write_file`, `ls`, or any direct file operation. @@ -33,6 +34,15 @@ Rules for `task`: - Neither's prompt references the other's output, and - They target different specialists, OR the same specialist with non-overlapping scopes (e.g. reading two unrelated paths). 
+- **Batch shape for many-shot fanout.** When a single user request expands + to **3 or more independent specialist calls** (e.g. "create five issues + from this list"), prefer the batch shape: + `task(tasks=[{description, subagent_type}, ...])`. The runtime fans them + out concurrently under a small semaphore and aggregates one ToolMessage + per child prefixed with `[task <index>]`. Batched children **do not + support human-in-the-loop interrupts** — if one needs approval it surfaces + an error and you re-dispatch it as a single (non-batched) `task(...)` call. + For 1–2 independent calls, just emit separate `task(...)` calls. - **Serialise dependent work across turns.** If one specialist's output must inform another's input (e.g. "find the roadmap in my KB, then email it to Maya"), invoke them on consecutive turns — first finishes, @@ -93,4 +103,65 @@ user: "Find my Q2 roadmap doc in the KB and email a summary to Maya." task(gmail, "Send an email to Maya with subject 'Q2 roadmap summary' and the following body: .") + + +user: "Create issues in Linear for each of these five bugs: " +→ Many-shot independent fanout — use the batch shape: + task(tasks=[ + {subagent_type: "linear", description: "Create a Linear issue titled + '' with body ''. Return the issue URL."}, + {subagent_type: "linear", description: "Create a Linear issue titled + '' with body ''. Return the issue URL."}, + {subagent_type: "linear", description: "Create a Linear issue titled + '' with body ''. Return the issue URL."}, + {subagent_type: "linear", description: "Create a Linear issue titled + '' with body ''. Return the issue URL."}, + {subagent_type: "linear", description: "Create a Linear issue titled + '' with body ''. Return the issue URL."}, + ]) + Read back the `[task 0]`…`[task 4]` blocks in the combined ToolMessage and + verify each via its Receipt's `verifiable_url` per the `<verification>` + teaching before confirming to the user. + + +user: "Make a 30-second podcast of this conversation." 
+→ Celery-backed deliverable. The `deliverables` subagent dispatches the + Celery job and then **waits for it to finish** before returning. The + call may take 10-60 seconds (or longer for video presentations) — + that is intentional, not a hang. You always get back one of two + Receipt shapes: + task(deliverables, "Generate a podcast titled '' from the + following content. Use a 30-second style brief. Return the podcast + id and title.\n\n<source content>") + Outcomes: + - **`status="success"`**: the audio is saved. Tell the user the + podcast is **ready** and quote the `external_id` / `preview` so + they can find it in the podcast panel. + - **`status="failed"`**: surface the Receipt's `error` field + verbatim. Do NOT silently re-dispatch — the backend already tried + and reported a real error. + Same two-way pattern applies to video presentations (which take + longer to render, but still return a terminal status). If a + `task(deliverables, ...)` invocation itself times out at the subagent + layer (separate from the Receipt), that's an operator-side problem + with the subagent invoke timeout, not a deliverable failure — pass + the message through and stop. +</example> + +<example> +user: "Post the launch announcement to #general and let me know when it's up." +→ Mutating subagent + user wants external confirmation. Apply the + `<verification>` teaching: the slack subagent's reply is a self-report; + check its `evidence.receipts` for a Receipt with `status="success"` and + a `verifiable_url`, then fetch that URL to confirm before reporting back. + This turn: + task(slack, "Post '<launch announcement text>' to #general. + Return the message permalink.") + Next turn (with the receipt's `verifiable_url` in hand): + scrape_webpage(url=<verifiable_url from slack receipt>) + → confirm the post is live, then tell the user it's up with the URL. 
+ If the slack reply has NO Receipt with `status="success"`, treat it as a + silent failure: surface the error verbatim, do not retry. +</example> </routing> diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py new file mode 100644 index 000000000..30699a4a1 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/__init__.py @@ -0,0 +1 @@ +"""``create_automation`` — description + few-shot examples.""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md new file mode 100644 index 000000000..ce6562c97 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/description.md @@ -0,0 +1,34 @@ +- `create_automation` — Draft and author a new automation. You describe the + user's intent; a focused drafter inside the tool turns it into the full + automation JSON; the user sees a preview on an approval card and chooses + approve or reject. All three phases happen in a single tool call. + - Call when the user wants SurfSense to do something on its own: anything + recurring or scheduled ("every morning…", "each Monday…", "weekly + recap…"). + - Args: + - `intent` (string): restate the user's request **concretely**, in one + paragraph. Cover three things: + - **What** should run (the action: summarize, recap, post, draft, …). + - **When** it should run (schedule + timezone if the user mentioned one; + otherwise leave the timezone for the drafter to default to UTC). 
+ - **Static values** the automation needs (folder ids, channel names, + project keys, parent page ids, …) — list them with their values. + If the user did NOT supply one the automation needs, say so + explicitly ("the Notion parent page id was not specified") so the + drafter leaves a placeholder. + - Do NOT prompt the user to confirm before calling — the approval card + IS the confirmation. The card shows a structured preview plus the raw + JSON; it offers approve/reject only. If the user wants changes after + seeing the draft, they reply in chat and you call this tool again with + a refined `intent` — that's the edit path. + - Returns: + - `{status: "saved", automation_id, name}` — confirm briefly to the + user ("Saved as automation #N — runs <when>."). Don't dump JSON back. + - `{status: "rejected", message}` — the user declined on the card. + Acknowledge once ("Understood, I didn't create it.") and stop. Do + NOT retry or pitch variants without a fresh user request. + - `{status: "invalid", issues, raw?}` — drafting/validation failed + before the card was shown. Read the issues, refine your `intent` + with the missing details, call again. + - `{status: "error", message}` — surface the message verbatim and + offer to retry. diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md new file mode 100644 index 000000000..19311bef0 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/create_automation/example.md @@ -0,0 +1,13 @@ +<example> +user: "Every weekday at 9am, summarize new documents in folder 12 and post the summary to Slack channel #daily-digest." +→ create_automation(intent="Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. 
Static inputs: folder_id=12, slack_channel='#daily-digest'.") +tool returns: {"status": "saved", "automation_id": 42, "name": "Daily folder 12 digest"} +(Reply briefly: "Saved as automation #42 — runs weekdays at 9am UTC.") +</example> + +<example> +user: "Once a week on Mondays at 7am Paris time, draft a Notion page recapping last week's Jira tickets in project CORE." +→ create_automation(intent="Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify which Notion page the recap should sit under — leave notion_parent_page_id as a placeholder.") +tool returns: {"status": "saved", "automation_id": 51, "name": "Weekly CORE Jira recap"} +(Reply: "Saved as automation #51. I left the Notion parent page id as a placeholder — set it on the automation before next Monday.") +</example> diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/__init__.py deleted file mode 100644 index c2cda318e..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""``search_surfsense_docs`` — description + few-shot examples.""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/description.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/description.md deleted file mode 100644 index 256d3f3a4..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/description.md +++ /dev/null @@ -1,10 +0,0 @@ -- `search_surfsense_docs` — Search official SurfSense documentation (product - 
help). - - Use when the user asks how SurfSense itself works — setup, configuration, - connector documentation, feature behavior, anything covered in the - product docs. - - Not a substitute for `task` when the user wants actions inside a - connected service (Gmail, Slack, Jira, Notion, etc.). - - Args: `query`, `top_k` (default 10). - - Returns doc excerpts; chunk ids may appear for attribution — see - `<citations>` for the contract. diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/example.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/example.md deleted file mode 100644 index d53ad8c91..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/search_surfsense_docs/example.md +++ /dev/null @@ -1,15 +0,0 @@ -<example> -user: "How do I install SurfSense?" -→ search_surfsense_docs(query="installation setup") -</example> - -<example> -user: "What connectors does SurfSense support?" -→ search_surfsense_docs(query="available connectors integrations") -</example> - -<example> -user: "How do I set up the Notion connector?" -→ search_surfsense_docs(query="Notion connector setup configuration") -(Changing data inside Notion itself → `task(notion, …)`, not this tool.) -</example> diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md index 2f47d4df1..d6a81d8d3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/tools/task/description.md @@ -4,12 +4,69 @@ `<specialists>` for the live roster. 
- Each subagent runs in isolation with its own tool stack and context, and returns a single synthesized result. - - Args: + - Args (single mode): - `subagent_type` — name of the specialist to invoke (must match an entry in `<specialists>`). - `description` — the FULL task prompt. The specialist cannot see this thread, so include all context and constraints, plus what you need back. The specialist will respond in its own format — don't dictate one. + - Args (batch mode): + - `tasks` — array of `{description, subagent_type}` objects to fan out + concurrently. Mutually exclusive with single-mode args. Use when a + single request expands to **3 or more independent specialist calls** + (e.g. "create five issues from this list"). Children run under a + small concurrency cap and the runtime returns one ToolMessage block + per child, prefixed with `[task <index>]`. **Batched children do not + support human-in-the-loop interrupts** — if any child needs approval + it surfaces an error and you must re-dispatch that single task as a + non-batched `task(...)` call. - Routing rules (when to call, how often, how to scope) live in `<routing>`. + <verification> + A subagent's natural-language reply is a **self-report**, not proof. The + specialist might claim a Slack message was posted, a Jira issue was + created, or a report was generated even when the underlying tool call + failed silently or was rate-limited. Treat success language ("Done", + "Posted to #general", "Created ENG-42") as a hypothesis, not a fact. + + Two ground-truth signals are always available to verify a mutating + subagent's claim: + + 1. **`state['receipts']`** — every mutating tool emits a structured + `Receipt` (route, type, operation, status, external_id, + verifiable_url, preview) into this append-only list. The supervisor + never sees the raw list directly, but each subagent's + `<output_contract>` carries the matching Receipt(s) under + `evidence.receipts`. 
If a subagent reports success with NO matching + Receipt at `status="success"` (or, rarely, `"pending"` for work that + was only kicked off), the operation did not happen — treat as + failure and surface that to the user verbatim, do not retry blindly. + + 2. **`scrape_webpage`** — when a Receipt carries a `verifiable_url` + (Notion page URL, Slack permalink, Jira issue URL, Linear identifier + URL, etc.), you can fetch that URL and confirm the operation + externally. Use this for high-stakes mutations the user explicitly + called out (e.g. "send the launch email to the whole team") or when + the subagent's self-report contradicts what the user expected. + + **Receipt status semantics — read carefully:** + + - `status="success"`: the mutation already committed in the backend. + If a `verifiable_url` is present and the request was high-stakes, + you may `scrape_webpage` it to externally confirm. Otherwise trust + the Receipt and tell the user it is done. Celery-backed deliverables + (podcasts, video presentations) also land here — the subagent + already waited for the worker to finish, so a `success` Receipt + means the artefact really is saved. + - `status="failed"`: a Receipt with this status carries the backend's + error in its `error` field. Surface that text verbatim to the user; + re-routing or retrying is only appropriate when the user explicitly + asks for it. + - `status="pending"`: rare today — current mutating tools wait for + their backend before returning. If you ever do see a pending + Receipt, tell the user the work has been **kicked off** (quote the + `external_id` / `preview` so they can find it later), do not + `scrape_webpage` it, and do not re-dispatch the same + `task(...)` call hoping it will be done "this time". 
+ </verification> diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/__init__.py new file mode 100644 index 000000000..d47bbac7e --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/__init__.py @@ -0,0 +1,7 @@ +"""``create_automation`` — author + persist an automation via a HITL card.""" + +from __future__ import annotations + +from .create import create_create_automation_tool + +__all__ = ["create_create_automation_tool"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py new file mode 100644 index 000000000..62d39fcf2 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py @@ -0,0 +1,214 @@ +"""``create_automation`` — NL intent → drafted JSON → HITL approval card → persisted. + +Single tool that: + +1. Drafts a structured automation from the user's intent via a focused sub-LLM + (system prompt in :mod:`.prompt`). +2. Surfaces the validated draft in a HITL approval card + (``action_type="automation_create"``). +3. On approval, validates the (possibly edited) payload again and persists + it via :class:`AutomationService`. + +The main agent only restates the user's request as a single ``intent`` string. +The drafting sub-LLM owns the JSON shape; the HITL card is the user's review. 
+""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any +from uuid import UUID + +from fastapi import HTTPException +from langchain.tools import ToolRuntime +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool +from pydantic import ValidationError + +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) +from app.automations.schemas.api import AutomationCreate +from app.automations.services.automation import AutomationService +from app.db import User, async_session_maker +from app.utils.content_utils import extract_text_content + +from .prompt import build_draft_prompt + +logger = logging.getLogger(__name__) + +_JSON_FENCE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL) + + +def create_create_automation_tool( + *, + search_space_id: int, + user_id: str | UUID, + llm: Any, +): + """Factory for the ``create_automation`` tool. + + ``search_space_id`` is injected from the chat session (the model never + has to guess it). ``llm`` is the drafting sub-model — we reuse the main + agent's LLM and tag the call so it's identifiable in traces. A fresh + ``AsyncSession`` is opened per call to avoid stale sessions on + compiled-agent cache hits (same pattern as the Notion / memory tools). + """ + uid = UUID(user_id) if isinstance(user_id, str) else user_id + + @tool + async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]: + """Draft + save an automation from a natural-language intent. + + Use this when the user wants SurfSense to do something on its own + on a schedule (e.g. "every morning summarize folder 12 to Slack"). + Restate the user's request as ONE concrete ``intent`` string: what + should run, when, and which static values (folder ids, channel + names, …) it needs. 
+ + The tool drafts the full automation JSON internally, shows the user + a structured preview on an approval card, and persists on approval. + The card supports approve/reject only — if the user wants edits + after seeing the draft, they say so in chat and you call this tool + again with a refined intent. Do NOT prompt the user to confirm + before calling — the card IS the confirmation. + + Args: + intent: Concrete restatement of the user's request. Include + the schedule (with timezone if mentioned), the action to + take, and any static values. Example: "Every weekday at + 09:00 UTC, summarize new docs added to folder_id=12 since + the last run, then post the summary to Slack channel + '#daily-digest'." + + Returns: + ``{"status": "saved", "automation_id": int, "name": str}`` on + approval + save. + ``{"status": "rejected", "message": "..."}`` when the user + declines on the card. + ``{"status": "invalid", "issues": [...], "raw": ...}`` when + the drafter produced output that did not validate (call again + with a more precise intent). + ``{"status": "error", "message": "..."}`` on drafter or + persistence failure. + + IMPORTANT: when status is ``"rejected"`` the user explicitly + declined. Acknowledge once and stop — do NOT retry or pitch + variants without a fresh user request. + """ + # Models are chosen per-automation on the approval card (premium/BYOK + # selectors) and validated when persisted by ``AutomationService.create`` + # — so there's no fail-fast search-space eligibility gate here. The + # search space's current chat/role model selection no longer constrains + # whether an automation can be drafted or saved. + + # --- 1. 
Draft via sub-LLM --- + prompt = build_draft_prompt(search_space_id=search_space_id, intent=intent) + try: + response = await llm.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal", "automation-draft"]}, + ) + except Exception as exc: + logger.exception("create_automation drafting LLM call failed") + return {"status": "error", "message": f"drafting failed: {exc}"} + + raw_text = extract_text_content(response.content).strip() + draft = _extract_json(raw_text) + if draft is None: + return { + "status": "invalid", + "issues": ["model output was not parseable JSON"], + "raw": raw_text, + } + + # search_space_id is injected here so the sub-LLM never has to guess. + draft["search_space_id"] = search_space_id + try: + validated_draft = AutomationCreate.model_validate(draft) + except ValidationError as exc: + return { + "status": "invalid", + "issues": _format_validation_issues(exc), + "raw": draft, + } + + # --- 2. HITL approval card --- + try: + card_params = validated_draft.model_dump(mode="json", by_alias=True) + # search_space_id is session-scoped, not user-editable. + card_params.pop("search_space_id", None) + + result = request_approval( + action_type="automation_create", + tool_name="create_automation", + params=card_params, + context={"search_space_id": search_space_id}, + tool_call_id=runtime.tool_call_id, + ) + + if result.rejected: + return { + "status": "rejected", + "message": "User declined. Do not retry or suggest alternatives.", + } + + # --- 3. 
Persist (re-validate in case the user edited) --- + final_payload = {**result.params, "search_space_id": search_space_id} + try: + final_validated = AutomationCreate.model_validate(final_payload) + except ValidationError as exc: + return { + "status": "invalid", + "issues": _format_validation_issues(exc), + } + + async with async_session_maker() as session: + user = await session.get(User, uid) + if user is None: + return { + "status": "error", + "message": "user not found in this session", + } + service = AutomationService(session=session, user=user) + created = await service.create(final_validated) + return { + "status": "saved", + "automation_id": created.id, + "name": created.name, + } + + except HTTPException as exc: + return {"status": "error", "message": exc.detail} + except Exception as exc: + from langgraph.errors import GraphInterrupt + + if isinstance(exc, GraphInterrupt): + raise + logger.exception("create_automation failed") + return {"status": "error", "message": f"persistence failed: {exc}"} + + return create_automation + + +def _extract_json(text: str) -> dict[str, Any] | None: + """Pull a JSON object out of the model response, tolerating ``` fences.""" + if not text: + return None + candidate = text + fence_match = _JSON_FENCE.search(text) + if fence_match: + candidate = fence_match.group(1) + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + return None + return parsed if isinstance(parsed, dict) else None + + +def _format_validation_issues(exc: ValidationError) -> list[str]: + return [ + f"{'.'.join(str(p) for p in err['loc'])}: {err['msg']}" for err in exc.errors() + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/prompt.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/prompt.py new file mode 100644 index 000000000..09854aa2e --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/prompt.py @@ -0,0 +1,178 @@ 
+"""System prompt for the drafting sub-LLM inside ``create_automation``. + +Converts a natural-language ``intent`` into a structured ``AutomationCreate`` +JSON object. That object becomes the payload the HITL approval card surfaces. + +Scope split: + Real automation JSONs live here — this is the graph that *generates* + the JSON. The main agent's prompt fragments (``description.md`` / + ``example.md``) only carry intent-string examples; the main agent + never sees the schema. + +Layout: + The prompt is concatenated from four format-safe pieces. ``_HEADER`` / + ``_FOOTER`` carry the only ``str.format`` placeholders; ``_SCHEMA`` and + ``_FEW_SHOTS`` are plain strings so their JSON literals (and the + ``{{ inputs.X }}`` Jinja references in queries) can stay readable + without doubled-brace escaping. + +Catalog handling: + v1 hard-codes the action/trigger catalog (one action, one trigger). + When new types ship, swap the inline lines for a render-time pull + from ``app.automations.actions`` / ``app.automations.triggers`` via + lazy imports inside :func:`build_draft_prompt` so this module never + participates in the ``multi_agent_chat`` import cycle. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +_HEADER = """\ +You are the SurfSense automation drafter. Convert the user intent below +into a SINGLE JSON object matching the AutomationCreate schema. Output +ONLY that JSON object — no prose, no markdown fence, no commentary. 
+ +Current UTC time (for cron context): {now} +Target search_space_id: {search_space_id} +""" + + +_SCHEMA = """ +Required JSON shape: +{ + "name": "<1-200 char identifier>", + "description": "<one-liner or null>", + "definition": { + "schema_version": "1.0", + "name": "<same as outer name>", + "goal": "<one sentence>", + "plan": [ + { + "step_id": "<slug>", + "action": "agent_task", + "params": { + "query": "<Jinja string referencing {{ inputs.X }}>", + "auto_approve_all": true + } + } + ], + "metadata": {"tags": ["..."]} + }, + "triggers": [ + { + "type": "schedule", + "params": {"cron": "<5-field cron>", "timezone": "<IANA tz, default UTC>"}, + "static_inputs": {"<key>": <value>, ...}, + "enabled": true + } + ] +} + +v1 catalog (only these are valid): +- Actions: agent_task — params: query (string, Jinja), auto_approve_all (bool). +- Triggers: schedule — params: cron (5-field), timezone (IANA, e.g. "UTC", + "Europe/Paris"). Has static_inputs (object). + +Conventions: +- Whatever the plan references via {{ inputs.X }} MUST appear either in a + trigger's static_inputs OR in definition.inputs.schema_.properties so the + executor can resolve it at fire time. +- static_inputs carries values that stay the same across every fire + (folder ids, channel names, project keys, parent page ids). Put them on + the trigger that supplies them, not in the plan. +- If the user did NOT supply a value the plan needs, put "REPLACE_ME" in + static_inputs. Do NOT invent ids, channels, or paths. +- Cron is 5-field (minute hour day-of-month month day-of-week). Use the + timezone the user mentioned; default "UTC" when unspecified. +- Templating variables available at fire time: inputs.* (merged + static_inputs + runtime), inputs.fired_at, inputs.last_fired_at. 
+""" + + +_FEW_SHOTS = """ +Few-shot examples (intent → JSON output): + +### Example 1 — schedule with all static values supplied +intent: "Every weekday at 09:00 UTC, summarize documents added to folder_id=12 since the last run, then post the summary to Slack channel '#daily-digest'. Static inputs: folder_id=12, slack_channel='#daily-digest'." +output: +{ + "name": "Daily folder 12 digest", + "description": "Weekday 09:00 UTC summary of folder 12 documents posted to #daily-digest", + "definition": { + "schema_version": "1.0", + "name": "Daily folder 12 digest", + "goal": "Summarize new docs in folder 12 since the last run and post to #daily-digest", + "plan": [ + { + "step_id": "summarize_and_post", + "action": "agent_task", + "params": { + "query": "Summarize documents added to folder {{ inputs.folder_id }} since {{ inputs.last_fired_at or 'yesterday' }}, then send the summary to Slack channel {{ inputs.slack_channel }}.", + "auto_approve_all": true + } + } + ], + "metadata": {"tags": ["daily", "digest", "slack"]} + }, + "triggers": [ + { + "type": "schedule", + "params": {"cron": "0 9 * * 1-5", "timezone": "UTC"}, + "static_inputs": {"folder_id": 12, "slack_channel": "#daily-digest"}, + "enabled": true + } + ] +} + +### Example 2 — schedule with a missing value (REPLACE_ME placeholder) +intent: "Every Monday at 07:00 Europe/Paris, read last week's Jira issues in project CORE, then draft a Notion page recapping them. Static inputs: jira_project_key='CORE'. The user did NOT specify the Notion parent page id — leave it as a placeholder." 
def build_draft_prompt(*, search_space_id: int, intent: str) -> str:
    """Assemble the system prompt handed to the drafting sub-LLM.

    The prompt is the concatenation, in order, of the formatted ``_HEADER``,
    the literal ``_SCHEMA`` and ``_FEW_SHOTS`` strings, and the formatted
    ``_FOOTER``. Only the header and footer go through ``str.format``.

    Args:
        search_space_id: Substituted into the header's ``{search_space_id}``
            placeholder.
        intent: The user's natural-language request; surrounding whitespace
            is stripped before it fills the footer's ``{intent}`` placeholder.

    Returns:
        The fully rendered prompt text.
    """
    # Second-resolution UTC timestamp for the header's ``{now}`` placeholder.
    timestamp = datetime.now(UTC).isoformat(timespec="seconds")
    header = _HEADER.format(now=timestamp, search_space_id=search_space_id)
    footer = _FOOTER.format(intent=intent.strip())
    return "".join((header, _SCHEMA, _FEW_SHOTS, footer))
= ( - "search_surfsense_docs", "web_search", "scrape_webpage", "update_memory", + "create_automation", ) MAIN_AGENT_SURFSENSE_TOOL_NAMES: frozenset[str] = frozenset( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py index 6c4519f3a..e11f3c3ec 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py @@ -2,6 +2,8 @@ from __future__ import annotations +import os + # Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS. EXCLUDED_STATE_KEYS = frozenset( { @@ -16,3 +18,72 @@ EXCLUDED_STATE_KEYS = frozenset( # Match the parent graph's budget; the LangGraph default of 25 trips on # multi-step subagent runs. DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000 + + +def _read_timeout_env(name: str, default: float) -> float: + """Parse ``name`` from the environment; fall back to ``default`` on bad values. + + Kept as a free function so the module-level constants stay constants + after import; tests can monkeypatch this and re-evaluate via + ``importlib.reload`` if they need a different value mid-process. + """ + raw = os.environ.get(name) + if not raw: + return default + try: + value = float(raw) + except (TypeError, ValueError): + return default + return value if value > 0 else default + + +# Wall-clock budget for a single ``task(subagent, ...)`` invocation. +# Subagents that run hot (image generation with slow vendors, KB writes +# behind a sluggish embedder) can otherwise wedge the orchestrator until +# the next checkpoint heartbeat. ``0`` disables the timeout entirely. 
+DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS: float = _read_timeout_env( + "SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS", + default=300.0, +) + + +def _read_int_env(name: str, default: int) -> int: + raw = os.environ.get(name) + if not raw: + return default + try: + value = int(raw) + except (TypeError, ValueError): + return default + return value if value > 0 else default + + +# Maximum number of children that ``task(..., tasks=[...])`` runs in +# parallel via ``asyncio.gather`` + ``Semaphore``. Bounded so a runaway +# fanout cannot starve unrelated subagents (each child still owns an +# LLM call + DB session). Set ``SURFSENSE_TASK_BATCH_CONCURRENCY=1`` to +# effectively serialise batches without changing the schema. +DEFAULT_SUBAGENT_BATCH_CONCURRENCY: int = _read_int_env( + "SURFSENSE_TASK_BATCH_CONCURRENCY", + default=3, +) + +# Max number of children in a single batched ``task`` call. Hard upper +# bound is a safety net for prompt-injection / runaway loops; the orchestrator +# rarely needs more than a handful of concurrent specialists. +MAX_SUBAGENT_BATCH_SIZE: int = _read_int_env( + "SURFSENSE_TASK_BATCH_MAX_SIZE", + default=8, +) + + +# Soft threshold for per-turn cumulative ``task(...)`` invocations across +# **all** subagents. Once the sum of ``state['billable_calls']`` values +# crosses this number, the runtime appends a one-shot warning ToolMessage +# instructing the orchestrator to wrap up the turn. Tunable so heavy-research +# turns (which legitimately need 15+ specialist calls) don't trip the alarm +# in production. Set to ``0`` to disable the warning entirely. 
+DEFAULT_SUBAGENT_BILLABLE_THRESHOLD: int = _read_int_env( + "SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD", + default=15, +) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py index 0119752c1..6cc71f252 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py @@ -16,6 +16,9 @@ from langchain.agents import create_agent from langchain.chat_models import init_chat_model from langgraph.types import Checkpointer +from app.agents.multi_agent_chat.subagents.shared.spec import ( + SURF_CONTEXT_HINT_PROVIDER_KEY, +) from app.utils.perf import get_perf_logger from .task_tool import build_task_tool_with_parent_config @@ -34,6 +37,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): subagents: list[SubAgent | CompiledSubAgent], system_prompt: str | None = TASK_SYSTEM_PROMPT, task_description: str | None = None, + search_space_id: int | None = None, ) -> None: self._surf_checkpointer = checkpointer super(SubAgentMiddleware, self).__init__() @@ -43,8 +47,17 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): ) self._backend = backend self._subagents = subagents + # Search-space id is captured at build time (the orchestrator runs in + # exactly one search space for its lifetime). The spawn-paused kill + # switch keys on it so an operator can quarantine one workspace + # without affecting the rest of the deployment. 
+ self._search_space_id = search_space_id subagent_specs = self._surf_compile_subagent_graphs() - task_tool = build_task_tool_with_parent_config(subagent_specs, task_description) + task_tool = build_task_tool_with_parent_config( + subagent_specs, + task_description, + search_space_id=search_space_id, + ) if system_prompt and subagent_specs: agents_desc = "\n".join( f"- {s['name']}: {s['description']}" for s in subagent_specs @@ -64,6 +77,10 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): for spec in self._subagents: spec_start = time.perf_counter() + # Provider may be ``None`` (no hint), in which case task_tool + # skips the prepend step. We forward the key unconditionally so + # the registry shape is uniform. + hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY) if "runnable" in spec: compiled = cast(CompiledSubAgent, spec) specs.append( @@ -71,6 +88,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): "name": compiled["name"], "description": compiled["description"], "runnable": compiled["runnable"], + SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider, } ) timings.append( @@ -108,6 +126,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): "name": spec["name"], "description": spec["description"], "runnable": runnable, + SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider, } ) timings.append( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/spawn_paused.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/spawn_paused.py new file mode 100644 index 000000000..2c9e114e0 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/spawn_paused.py @@ -0,0 +1,84 @@ +"""Per-search-space spawn-paused kill switch for the ``task`` boundary. 
async def is_spawn_paused(search_space_id: int | None) -> bool:
    """Report whether the spawn-paused Redis flag is set for a search space.

    Args:
        search_space_id: Workspace whose flag is looked up. ``None`` skips
            the lookup entirely.

    Returns:
        ``True`` when the value stored under ``_flag_key(search_space_id)``
        is truthy. ``False`` when the check is switched off via
        ``_DISABLED``, when no id was supplied, when the key is missing or
        empty, or when any step of the Redis round-trip raises (fail-open;
        the failure is logged at WARNING with the traceback).
    """
    if _DISABLED:
        return False
    if search_space_id is None:
        return False

    try:
        # Imported inside the function so importing this module does not
        # itself require the redis package.
        import redis.asyncio as aioredis  # type: ignore[import-not-found]

        redis_client = aioredis.from_url(
            config.REDIS_APP_URL, decode_responses=True
        )
        try:
            flag_value = await redis_client.get(_flag_key(search_space_id))
        finally:
            # Prefer ``aclose`` when the client exposes it; otherwise fall
            # back to ``close``. Errors while closing are ignored.
            closer = getattr(redis_client, "aclose", None)
            if not closer:
                closer = getattr(redis_client, "close", None)
            if callable(closer):
                with contextlib.suppress(Exception):
                    await closer()  # type: ignore[misc]
        return bool(flag_value)
    except Exception:
        logger.warning(
            "spawn_paused check failed for search_space_id=%s; failing open.",
            search_space_id,
            exc_info=True,
        )
        return False
from __future__ import annotations +import asyncio +import json import logging import time -from typing import Annotated, Any, NoReturn +from collections.abc import Awaitable +from typing import Annotated, Any, NoReturn, TypeVar from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION from langchain.tools import BaseTool, ToolRuntime @@ -20,6 +23,11 @@ from langchain_core.tools import StructuredTool from langgraph.errors import GraphInterrupt from langgraph.types import Command, Interrupt +from app.agents.multi_agent_chat.subagents.shared.spec import ( + SURF_CONTEXT_HINT_PROVIDER_KEY, + ContextHintProvider, +) +from app.observability import metrics as ot_metrics, otel as ot from app.utils.perf import get_perf_logger from .config import ( @@ -28,7 +36,13 @@ from .config import ( has_surfsense_resume, subagent_invoke_config, ) -from .constants import EXCLUDED_STATE_KEYS +from .constants import ( + DEFAULT_SUBAGENT_BATCH_CONCURRENCY, + DEFAULT_SUBAGENT_BILLABLE_THRESHOLD, + DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS, + EXCLUDED_STATE_KEYS, + MAX_SUBAGENT_BATCH_SIZE, +) from .propagation import wrap_with_tool_call_id from .resume import ( build_resume_command, @@ -36,11 +50,70 @@ from .resume import ( get_first_pending_subagent_interrupt, hitlrequest_action_count, ) +from .spawn_paused import is_spawn_paused logger = logging.getLogger(__name__) _perf_log = get_perf_logger() +class SubagentInvokeTimeoutError(Exception): + """Raised when ``subagent.ainvoke`` exceeds the configured wall-clock budget. + + Carries the subagent name and the elapsed seconds so the caller can + synthesize a ToolMessage that the orchestrator can act on (re-route, + surface to the user, or retry with a smaller scope). 
async def _ainvoke_with_timeout(
    coro: Awaitable[_T], *, subagent_type: str, started_at: float
) -> _T:
    """Await ``coro`` under :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS`.

    The signature is generic over the module-level ``_T`` TypeVar only. The
    previous declaration also carried an unused PEP 695 ``[T]`` parameter
    while annotating with ``_T``, mixing the two generic syntaxes in one
    function.

    Args:
        coro: The awaitable to run (typically ``subagent.ainvoke(...)``).
        subagent_type: Name recorded on the timeout error.
        started_at: ``time.perf_counter()`` reading taken by the caller
            before the awaitable was created; used to compute elapsed time.

    Returns:
        Whatever ``coro`` resolves to.

    Raises:
        SubagentInvokeTimeoutError: When ``asyncio.wait_for`` raises
            ``TimeoutError``. The original exception is chained as
            ``__cause__``.
    """
    timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
    if timeout <= 0:
        # A non-positive budget means "no cap": await directly.
        return await coro
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except TimeoutError as exc:
        # NOTE(review): a ``TimeoutError`` raised from inside ``coro`` itself
        # is indistinguishable here from the budget expiring and is reported
        # the same way - confirm that is acceptable for callers.
        elapsed = time.perf_counter() - started_at
        raise SubagentInvokeTimeoutError(subagent_type, elapsed) from exc
+ ) + return Command( + update={"messages": [ToolMessage(content=content, tool_call_id=tool_call_id)]} + ) + + def _reraise_stamped_subagent_interrupt( gi: GraphInterrupt, tool_call_id: str ) -> NoReturn: @@ -69,11 +142,24 @@ def _reraise_stamped_subagent_interrupt( def build_task_tool_with_parent_config( subagents: list[dict[str, Any]], task_description: str | None = None, + *, + search_space_id: int | None = None, ) -> BaseTool: """Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging.""" subagent_graphs: dict[str, Runnable] = { spec["name"]: spec["runnable"] for spec in subagents } + # Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``). + # The mapping is sparse: only routes that opted in via ``pack_subagent`` + # appear here, and the value is invoked once per ``task(...)`` call to + # generate a short string prepended to the subagent's first + # ``HumanMessage``. Failures are logged and swallowed — a broken hint + # provider must never prevent the underlying task from running. + subagent_hint_providers: dict[str, ContextHintProvider] = { + spec["name"]: provider + for spec in subagents + if (provider := spec.get(SURF_CONTEXT_HINT_PROVIDER_KEY)) is not None + } subagent_description_str = "\n".join( f"- {s['name']}: {s['description']}" for s in subagents ) @@ -87,6 +173,120 @@ def build_task_tool_with_parent_config( else: description = task_description + def _billable_call_update( + subagent_type: str, runtime: ToolRuntime + ) -> dict[str, Any]: + """Build the per-call ``billable_calls`` delta + an optional warning. + + The orchestrator's ``billable_calls`` map is summed by + :func:`_int_counter_merge_reducer`, so we always emit + ``{subagent_type: 1}`` and let the reducer accumulate. If the + cumulative count *after* this call would cross the configured + threshold, we also slip a soft ``messages`` entry into the update + so the orchestrator can read it on its next step and self-limit. 
+ Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps + the helper composable with the existing single/batch return paths. + """ + delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}} + threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD + if threshold <= 0: + return delta + prior = runtime.state.get("billable_calls") or {} + # ``prior`` may be a plain dict or a reducer-managed mapping; only + # int values are counted so a malformed checkpoint can't crash us. + prior_total = sum(v for v in prior.values() if isinstance(v, int)) + new_total = prior_total + 1 + if prior_total < threshold <= new_total: + warn = ( + f"[budget warning] This turn has dispatched {new_total} " + f"subagent calls (soft cap = {threshold}). Wrap up the " + "user's request with what you have rather than launching " + "more specialists; surface a partial answer if needed." + ) + delta["_billable_warn_text"] = warn + return delta + + def _attach_billable( + cmd: Command, subagent_type: str, runtime: ToolRuntime + ) -> Command: + """Merge the per-call billable counter (and warning) into ``cmd``.""" + delta = _billable_call_update(subagent_type, runtime) + warn_text = delta.pop("_billable_warn_text", None) + # ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively + # copy so we don't mutate state shared across other tool returns. + update = dict(getattr(cmd, "update", {}) or {}) + for key, value in delta.items(): + update[key] = value + if warn_text: + existing_msgs = list(update.get("messages") or []) + existing_msgs.append( + ToolMessage(content=warn_text, tool_call_id=runtime.tool_call_id) + ) + update["messages"] = existing_msgs + return Command(update=update) + + def _safe_message_text(msg: Any) -> str: + """Pull text out of a BaseMessage without trusting the ``.text`` property. 
+ + ``BaseMessage.text`` walks ``content_blocks`` and crashes with + ``TypeError: 'NoneType' object is not iterable`` when ``content`` is + ``None`` (common for tool-call AIMessages whose payload is purely + structured). ``getattr(msg, "text", None)`` does not catch this + because Python evaluates the property body before falling back to + the default. Read ``content`` directly and coerce defensively. + """ + try: + content = getattr(msg, "content", None) + except Exception: + content = None + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict): + block_text = block.get("text") or block.get("content") + if isinstance(block_text, str): + parts.append(block_text) + return " ".join(parts) + return str(content) + + def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]: + """Compress the subagent's message stream into a compact tool trace. + + Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview": + <≤120 chars>}`` so the orchestrator can show "this is what your + specialist actually did" without dumping the full message stream + back through the prompt. The list is attached to the returned + ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``); + the LLM never sees it, but UI / observability code can pluck it + out of the checkpoint. + """ + trace: list[dict[str, Any]] = [] + for msg in messages: + tool_name = getattr(msg, "name", None) + tool_call_id_attr = getattr(msg, "tool_call_id", None) + if not tool_name and not tool_call_id_attr: + # Only ToolMessages have either field; skip AIMessage / + # HumanMessage / SystemMessage frames. + continue + status = getattr(msg, "status", None) or "ok" + preview = _safe_message_text(msg).strip().replace("\n", " ") + if len(preview) > 120: + preview = preview[:117] + "..." 
+ trace.append( + { + "tool": tool_name or "<unknown>", + "status": status, + "preview": preview, + } + ) + return trace + def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command: if "messages" not in result: msg = ( @@ -105,15 +305,51 @@ def build_task_tool_with_parent_config( "output to forward back to the user." ) raise ValueError(msg) - last_text = getattr(messages[-1], "text", None) or "" - message_text = last_text.rstrip() + message_text = _safe_message_text(messages[-1]).rstrip() + # Tool-trace is purely observability — wrap defensively so a single + # malformed frame never bubbles up and kills the whole user turn. + try: + tool_trace = _build_tool_trace(messages) + except Exception: + logger.exception( + "Failed to build tool_trace for subagent return; " + "continuing without trace." + ) + tool_trace = [] + tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id) + if tool_trace: + # ``additional_kwargs`` is a free-form dict on BaseMessage; using + # a ``surf_`` prefix avoids collision with provider-specific keys + # (e.g. Anthropic's ``cache_control``). The LLM doesn't see it; + # consumers (UI, observability) read it off the checkpoint. 
def _resolve_context_hint(
    subagent_type: str, description: str, runtime: ToolRuntime
) -> str | None:
    """Ask the subagent's hint provider for a context hint, never raising.

    Args:
        subagent_type: Key used to look up the provider in
            ``subagent_hint_providers``.
        description: Task description forwarded to the provider.
        runtime: Tool runtime; its ``state`` is forwarded to the provider.

    Returns:
        The provider's result with surrounding whitespace stripped, or
        ``None`` when no provider is registered, the provider raises (the
        exception is logged), or the result is falsy, not a string, or blank
        after stripping.
    """
    hint_fn = subagent_hint_providers.get(subagent_type)
    if hint_fn is None:
        return None

    try:
        raw_hint = hint_fn(runtime.state, description)
    except Exception:
        logger.exception(
            "Context-hint provider for subagent %r raised; skipping hint.",
            subagent_type,
        )
        return None

    # Truthiness is checked before the type, matching the provider contract
    # of "non-empty string or nothing".
    if raw_hint and isinstance(raw_hint, str):
        return raw_hint.strip() or None
    return None
State updates are merged by reducer for keys outside + :data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``, + etc.) is replaced by the synthesized aggregate ToolMessage. Every + child also contributes a ``billable_calls`` increment so cost + accounting matches single-mode dispatch. + """ + results.sort(key=lambda r: r[0]) + merged_state: dict[str, Any] = {} + billable_delta: dict[str, int] = {} + message_blocks: list[str] = [] + batch_trace: list[dict[str, Any]] = [] + for task_index, subagent_type, payload, state_update in results: + billable_delta[subagent_type] = billable_delta.get(subagent_type, 0) + 1 + if isinstance(payload, str): + # Pre-flight error or per-task exception text. + message_blocks.append(f"[task {task_index}] {payload}") + batch_trace.append( + { + "task_index": task_index, + "subagent_type": subagent_type, + "status": "error", + "tool_trace": [], + } + ) + continue + messages = payload.get("messages") or [] + last_text = _safe_message_text(messages[-1]).rstrip() if messages else "" + message_blocks.append(f"[task {task_index}] {last_text or '<empty>'}") + try: + child_trace = _build_tool_trace(messages) + except Exception: + logger.exception( + "Failed to build tool_trace for batch task_index=%d; continuing.", + task_index, + ) + child_trace = [] + batch_trace.append( + { + "task_index": task_index, + "subagent_type": subagent_type, + "status": "ok", + "tool_trace": child_trace, + } + ) + if state_update: + # Naive merge: later tasks win on scalar collisions; reducer-backed + # fields (``receipts``, ``files`` etc.) accumulate at apply time. 
+ merged_state.update(state_update) + aggregate = "\n\n".join(message_blocks) + aggregate_msg = ToolMessage( + content=aggregate, tool_call_id=runtime.tool_call_id + ) + if batch_trace: + aggregate_msg.additional_kwargs["surf_tool_trace"] = batch_trace + update: dict[str, Any] = { + **merged_state, + "billable_calls": billable_delta, + "messages": [aggregate_msg], + } + # Soft-cap warning: check the cumulative count after attribution. + threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD + if threshold > 0: + prior = runtime.state.get("billable_calls") or {} + prior_total = sum(v for v in prior.values() if isinstance(v, int)) + new_total = prior_total + sum(billable_delta.values()) + if prior_total < threshold <= new_total: + update["messages"].append( + ToolMessage( + content=( + f"[budget warning] This turn has dispatched " + f"{new_total} subagent calls (soft cap = " + f"{threshold}). Wrap up the user's request with " + "what you have rather than launching more " + "specialists; surface a partial answer if needed." + ), + tool_call_id=runtime.tool_call_id, + ) + ) + return Command(update=update) + + async def _ainvoke_one_batch_child( + *, + task_index: int, + subagent_type: str, + description: str, + runtime: ToolRuntime, + semaphore: asyncio.Semaphore, + ) -> tuple[int, str, dict | str, dict | None]: + """Run one child of a batched ``task`` call under the concurrency cap. + + Errors are returned as plain text in slot 2 so a single child's + failure does not abort the whole batch. ``GraphInterrupt`` from a + batched child is currently treated as a hard failure for that child + only — batched HITL is intentionally out of scope for the v1 + rollout (see plan tier 2 item 4 risks). 
+ """ + async with semaphore: + if subagent_type not in subagent_graphs: + allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs]) + return ( + task_index, + subagent_type, + ( + f"Subagent {subagent_type!r} does not exist; " + f"allowed: {allowed_types}" + ), + None, + ) + subagent, subagent_state = _validate_and_prepare_state( + subagent_type, description, runtime + ) + sub_config = subagent_invoke_config(runtime) + started_at = time.perf_counter() + try: + result = await _ainvoke_with_timeout( + subagent.ainvoke(subagent_state, config=sub_config), + subagent_type=subagent_type, + started_at=started_at, + ) + except SubagentInvokeTimeoutError as exc: + logger.warning( + "Batch child %d (%s) timed out after %.1fs", + task_index, + subagent_type, + exc.elapsed_seconds, + ) + return (task_index, subagent_type, str(exc), None) + except GraphInterrupt: + # Batched HITL is unsupported in v1 — surface as a failure + # for this child so the rest of the batch still completes. + logger.warning( + "Batch child %d (%s) raised GraphInterrupt; batched HITL " + "is not supported. Re-dispatch this task as a single " + "(non-batched) `task(...)` call to get the HITL prompt.", + task_index, + subagent_type, + ) + return ( + task_index, + subagent_type, + ( + f"Subagent {subagent_type!r} needs human approval. " + "Re-dispatch this task as a single (non-batched) " + "`task(...)` call so the approval card can be shown." + ), + None, + ) + except Exception as exc: + logger.exception( + "Batch child %d (%s) raised: %s", + task_index, + subagent_type, + exc, + ) + return ( + task_index, + subagent_type, + f"Subagent {subagent_type!r} error: {exc}", + None, + ) + child_state_update = { + k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS + } + return (task_index, subagent_type, result, child_state_update) + + def _coerce_batch_arg(tasks: Any) -> list[dict] | str: + """Rescue common LLM-side malformations of the ``tasks`` argument. 
+ + Some providers serialise an array argument as a JSON-encoded string, + and small models occasionally hand back a single ``{description, + subagent_type}`` dict instead of a one-element array. Both are + recovered here with a WARN log so the issue is visible in metrics + but the user's turn still completes; truly broken shapes return a + plain string that the caller surfaces as the tool error. + """ + if isinstance(tasks, list): + return tasks + if isinstance(tasks, dict): + logger.warning( + "task: `tasks` was a single dict; coercing to a 1-element list. " + "Orchestrators should send `tasks=[{...}]` directly." + ) + return [tasks] + if isinstance(tasks, str): + stripped = tasks.strip() + if not stripped: + return "tasks: argument is empty." + try: + parsed = json.loads(stripped) + except json.JSONDecodeError as exc: + return ( + f"tasks: argument is a string but not valid JSON ({exc.msg}). " + "Send a JSON array of `{description, subagent_type}` objects." + ) + logger.warning( + "task: `tasks` was a JSON-encoded string; parsed to %s. " + "Orchestrators should send a JSON array directly.", + type(parsed).__name__, + ) + return _coerce_batch_arg(parsed) + return ( + f"tasks: unsupported type {type(tasks).__name__}; expected an array " + "of `{description, subagent_type}` objects." + ) + + async def _adispatch_batch( + tasks: list[dict], runtime: ToolRuntime + ) -> Command | str: + """Fan-out helper for the ``tasks`` array shape. + + Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped + at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single + :class:`Command` that the LLM sees as one ToolMessage per child, + prefixed with ``[task <index>]`` so it can map back to the input + order. + """ + if not tasks: + return "tasks: array is empty; nothing to dispatch." + if len(tasks) > MAX_SUBAGENT_BATCH_SIZE: + return ( + f"tasks: too many children ({len(tasks)}); " + f"max is {MAX_SUBAGENT_BATCH_SIZE}. Split the batch." 
+ ) + normalized: list[tuple[int, str, str]] = [] + for idx, item in enumerate(tasks): + if not isinstance(item, dict): + return ( + f"tasks[{idx}]: must be an object with description+subagent_type." + ) + description = item.get("description") + subagent_type = item.get("subagent_type") + if not isinstance(description, str) or not description.strip(): + return f"tasks[{idx}]: missing or empty 'description'." + if not isinstance(subagent_type, str) or not subagent_type.strip(): + return f"tasks[{idx}]: missing or empty 'subagent_type'." + normalized.append((idx, subagent_type.strip(), description)) + semaphore = asyncio.Semaphore(DEFAULT_SUBAGENT_BATCH_CONCURRENCY) + coros = [ + _ainvoke_one_batch_child( + task_index=idx, + subagent_type=subagent_type, + description=description, + runtime=runtime, + semaphore=semaphore, + ) + for idx, subagent_type, description in normalized + ] + results = await asyncio.gather(*coros) + return _merge_batch_results(list(results), runtime) + def task( description: Annotated[ - str, - "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", - ], + str | None, + "Single-mode: a detailed task description for the subagent. Required unless `tasks` is provided.", + ] = None, subagent_type: Annotated[ - str, - "The type of subagent to use. Must be one of the available agent types listed in the tool description.", - ], - runtime: ToolRuntime, + str | None, + "Single-mode: the type of subagent to use. Required unless `tasks` is provided.", + ] = None, + runtime: ToolRuntime = None, # type: ignore[assignment] + tasks: Annotated[ + list[dict] | None, + ( + "Batch-mode: array of `{description, subagent_type}` objects. " + "Synchronous path does not support batch mode; orchestrators " + "must use the async event loop to fan out." 
+ ), + ] = None, ) -> str | Command: + if tasks is not None: + return ( + "task: batch mode (`tasks=[...]`) is only supported on the async " + "path. SurfSense orchestrators always run in an event loop, so " + "this should never fire — file a bug if you see it." + ) + if not description or not subagent_type: + return ( + "task: must provide either single-mode (`description`+`subagent_type`) " + "or batch-mode (`tasks`)." + ) if subagent_type not in subagent_graphs: allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs]) return ( @@ -173,6 +695,9 @@ def build_task_tool_with_parent_config( exc_info=True, ) + invoke_path = "resume" if pending_value is not None else "fresh" + invoke_start = time.perf_counter() + invoke_outcome = "ok" if pending_value is not None: resume_value = consume_surfsense_resume(runtime) if resume_value is None: @@ -188,32 +713,157 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) - try: - result = subagent.invoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) - except GraphInterrupt as gi: - _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + with ot.subagent_invoke_span( + subagent_type=subagent_type, path=invoke_path + ) as sp: + try: + result = subagent.invoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ) + sp.set_attribute("subagent.outcome", invoke_outcome) + except GraphInterrupt as gi: + invoke_outcome = "interrupted" + sp.set_attribute("subagent.outcome", invoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - invoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + except Exception: + invoke_outcome = "error" + sp.set_attribute("subagent.outcome", invoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - invoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + raise else: - try: - result = subagent.invoke(subagent_state, config=sub_config) - except GraphInterrupt as gi: - _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + with ot.subagent_invoke_span( + subagent_type=subagent_type, path=invoke_path + ) as sp: + try: + result = subagent.invoke(subagent_state, config=sub_config) + sp.set_attribute("subagent.outcome", invoke_outcome) + except GraphInterrupt as gi: + invoke_outcome = "interrupted" + sp.set_attribute("subagent.outcome", invoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - invoke_start) * 1000, + 
subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + except Exception: + invoke_outcome = "error" + sp.set_attribute("subagent.outcome", invoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - invoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + raise + invoke_elapsed_ms = (time.perf_counter() - invoke_start) * 1000 + ot_metrics.record_subagent_invoke_duration( + invoke_elapsed_ms, + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=invoke_outcome, + ) return _return_command_with_state_update(result, runtime.tool_call_id) async def atask( description: Annotated[ - str, - "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", - ], + str | None, + "Single-mode: a detailed task description for the subagent. Required unless `tasks` is provided.", + ] = None, subagent_type: Annotated[ - str, - "The type of subagent to use. Must be one of the available agent types listed in the tool description.", - ], - runtime: ToolRuntime, + str | None, + "Single-mode: the type of subagent to use. Required unless `tasks` is provided.", + ] = None, + runtime: ToolRuntime = None, # type: ignore[assignment] + tasks: Annotated[ + list[dict] | None, + ( + "Batch-mode: array of `{description, subagent_type}` objects " + "to fan out concurrently (max " + f"{MAX_SUBAGENT_BATCH_SIZE}, concurrency " + f"{DEFAULT_SUBAGENT_BATCH_CONCURRENCY}). 
Mutually exclusive " + "with single-mode args. Batched children do not support " + "human-in-the-loop interrupts; re-dispatch as single mode " + "if a child needs approval." + ), + ] = None, ) -> str | Command: atask_start = time.perf_counter() + # Kill switch: when ops flips the spawn-paused flag for this + # workspace, every ``task(...)`` invocation (single- or batch-mode) + # short-circuits with a clear ToolMessage so the orchestrator can + # tell the user what happened and stop hammering downstream APIs. + if await is_spawn_paused(search_space_id): + logger.warning( + "[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s", + search_space_id, + runtime.tool_call_id, + ) + return ( + "task: subagent dispatch is currently paused for this workspace. " + "Acknowledge to the user that delegation is temporarily disabled " + "(ops kill switch); do not retry until the pause is lifted." + ) + if tasks is not None: + if description or subagent_type: + return ( + "task: cannot combine `tasks` with `description`/`subagent_type`. " + "Use either single-mode (description+subagent_type) or batch-mode (tasks)." + ) + if not runtime.tool_call_id: + raise ValueError("Tool call ID is required for subagent invocation") + coerced = _coerce_batch_arg(tasks) + if isinstance(coerced, str): + return coerced + logger.info( + "[hitl_route] atask BATCH ENTRY: size=%d tool_call_id=%s", + len(coerced), + runtime.tool_call_id, + ) + return await _adispatch_batch(coerced, runtime) + if not description or not subagent_type: + return ( + "task: must provide either single-mode (`description`+`subagent_type`) " + "or batch-mode (`tasks`)." + ) logger.info( "[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s", subagent_type, @@ -274,40 +924,154 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) - try: - result = await subagent.ainvoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) - except GraphInterrupt as gi: - ainvoke_outcome = "interrupted" - _perf_log.info( - "[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s " - "aget_state=%.3fs ainvoke=%.3fs total=%.3fs", - subagent_type, - invoke_path, - ainvoke_outcome, - aget_state_elapsed, - time.perf_counter() - ainvoke_start, - time.perf_counter() - atask_start, - ) - _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + with ot.subagent_invoke_span( + subagent_type=subagent_type, path=invoke_path + ) as sp: + try: + result = await _ainvoke_with_timeout( + subagent.ainvoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ), + subagent_type=subagent_type, + started_at=ainvoke_start, + ) + sp.set_attribute("subagent.outcome", ainvoke_outcome) + except SubagentInvokeTimeoutError as exc: + ainvoke_outcome = "timeout" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + logger.warning( + "Subagent %r ainvoke (resume) timed out after %.1fs", + subagent_type, + exc.elapsed_seconds, + ) + return _synthesize_timeout_command( + exc, tool_call_id=runtime.tool_call_id + ) + except GraphInterrupt as gi: + ainvoke_outcome = "interrupted" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + _perf_log.info( + 
"[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s " + "aget_state=%.3fs ainvoke=%.3fs total=%.3fs", + subagent_type, + invoke_path, + ainvoke_outcome, + aget_state_elapsed, + time.perf_counter() - ainvoke_start, + time.perf_counter() - atask_start, + ) + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + except Exception: + ainvoke_outcome = "error" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + raise else: - try: - result = await subagent.ainvoke(subagent_state, config=sub_config) - except GraphInterrupt as gi: - ainvoke_outcome = "interrupted" - _perf_log.info( - "[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s " - "aget_state=%.3fs ainvoke=%.3fs total=%.3fs", - subagent_type, - invoke_path, - ainvoke_outcome, - aget_state_elapsed, - time.perf_counter() - ainvoke_start, - time.perf_counter() - atask_start, - ) - _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + with ot.subagent_invoke_span( + subagent_type=subagent_type, path=invoke_path + ) as sp: + try: + result = await _ainvoke_with_timeout( + subagent.ainvoke(subagent_state, config=sub_config), + subagent_type=subagent_type, + started_at=ainvoke_start, + ) + sp.set_attribute("subagent.outcome", ainvoke_outcome) + except SubagentInvokeTimeoutError as exc: + ainvoke_outcome = "timeout" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + 
logger.warning( + "Subagent %r ainvoke (fresh) timed out after %.1fs", + subagent_type, + exc.elapsed_seconds, + ) + return _synthesize_timeout_command( + exc, tool_call_id=runtime.tool_call_id + ) + except GraphInterrupt as gi: + ainvoke_outcome = "interrupted" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + _perf_log.info( + "[hitl_route] atask EXIT subagent_type=%r path=%s outcome=%s " + "aget_state=%.3fs ainvoke=%.3fs total=%.3fs", + subagent_type, + invoke_path, + ainvoke_outcome, + aget_state_elapsed, + time.perf_counter() - ainvoke_start, + time.perf_counter() - atask_start, + ) + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) + except Exception: + ainvoke_outcome = "error" + sp.set_attribute("subagent.outcome", ainvoke_outcome) + ot_metrics.record_subagent_invoke_duration( + (time.perf_counter() - ainvoke_start) * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + raise ainvoke_elapsed = time.perf_counter() - ainvoke_start except GraphInterrupt: raise @@ -326,7 +1090,18 @@ def build_task_tool_with_parent_config( merge_elapsed, time.perf_counter() - atask_start, ) - return cmd + ot_metrics.record_subagent_invoke_duration( + ainvoke_elapsed * 1000, + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + ot_metrics.record_subagent_invoke_outcome( + subagent_type=subagent_type, + path=invoke_path, + outcome=ainvoke_outcome, + ) + return _attach_billable(cmd, subagent_type, runtime) return StructuredTool.from_function( name="task", diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py index e8a4c9899..2685d8a9b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/kb_context_projection.py @@ -52,9 +52,7 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg] messages.insert(insert_at, SystemMessage(content=tree_text)) priority_count = 0 if priority: - priority_count = ( - len(priority) if hasattr(priority, "__len__") else 1 - ) + priority_count = len(priority) if hasattr(priority, "__len__") else 1 messages.insert(insert_at, _render_priority_message(priority)) _perf_log.info( "[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs", diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py index d61d38f34..3db51883d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py @@ -17,7 +17,7 @@ from langchain_core.tools import BaseTool from langgraph.types import interrupt from app.agents.new_chat.permissions import Rule -from app.observability import otel as ot +from app.observability import metrics as ot_metrics, otel as ot from .decision import normalize_permission_decision from .payload import PERMISSION_ASK_INTERRUPT_TYPE, build_permission_ask_payload @@ -52,6 +52,8 @@ def request_permission_decision( ), ot.interrupt_span(interrupt_type=PERMISSION_ASK_INTERRUPT_TYPE), ): + ot_metrics.record_permission_ask(permission=tool_name) + ot_metrics.record_interrupt(interrupt_type=PERMISSION_ASK_INTERRUPT_TYPE) decision = interrupt(payload) 
return normalize_permission_decision(decision) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py index c1ebe31ca..3b20d8915 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py @@ -173,6 +173,7 @@ def build_main_agent_deepagent_middleware( subagents=subagents, system_prompt=None, task_description=TASK_TOOL_DESCRIPTION, + search_space_id=search_space_id, ), resilience.model_call_limit, resilience.tool_call_limit, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md index c44f131bb..413791037 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/system_prompt.md @@ -42,14 +42,16 @@ Return **only** one JSON object (no markdown/prose): "evidence": { "artifact_type": "report" | "podcast" | "video_presentation" | "resume" | "image" | null, "artifact_id": string | null, - "artifact_location": string | null + "artifact_location": string | null, + "receipts": Receipt[] | null }, "next_step": string | null, "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` -> `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` -> `next_step` must be non-null. -- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null. +Route-specific rules: +- `evidence.receipts` quotes the Receipt(s) returned by `generate_report` / `generate_podcast` / `generate_video_presentation` / `generate_resume` / `generate_image` this turn, verbatim. 
The Receipt's `type` enum is one of `report` | `podcast` | `video_presentation` | `resume` | `image`. +<include snippet="output_contract_base"/> </output_contract> + +<include snippet="verifiable_handle"/> diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index ab9dbc0ea..094371760 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -4,11 +4,15 @@ import hashlib import logging from typing import Any +from langchain.tools import ToolRuntime from langchain_core.tools import tool +from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.config import config from app.db import ( ImageGeneration, @@ -59,15 +63,22 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, + image_generation_config_id_override: int | None = None, ): - """Create ``generate_image`` with bound search space; DB work uses a per-call session.""" + """Create ``generate_image`` with bound search space; DB work uses a per-call session. + + ``image_generation_config_id_override``: when set (automations running on a + captured model), use this config id instead of reading the search space's + live ``image_generation_config_id``. 
+ """ del db_session # use a fresh per-call session, see below @tool async def generate_image( prompt: str, + runtime: ToolRuntime, n: int = 1, - ) -> dict[str, Any]: + ) -> Command: """ Generate an image from a text description using AI image models. @@ -82,22 +93,48 @@ def create_generate_image_tool( Returns: A dictionary containing the generated image(s) for display in the chat. """ + + def _failed(payload: dict[str, Any], *, error: str) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="image", + operation="generate", + status="failed", + preview=prompt[:200] if prompt else None, + error=error, + ), + tool_call_id=runtime.tool_call_id, + ) + try: # Use a per-call session so concurrent tool calls don't share an # AsyncSession (which is not concurrency-safe). The streaming # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return {"error": "Search space not found"} + if image_generation_config_id_override is not None: + # Automation run: use the captured image model, insulated from + # later search-space changes. No search-space read needed. 
+ config_id = ( + image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + ) + else: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) + config_id = ( + search_space.image_generation_config_id + or IMAGE_GEN_AUTO_MODE_ID + ) # Build generation kwargs # NOTE: size, quality, and style are intentionally NOT passed. @@ -112,19 +149,19 @@ def create_generate_image_tool( # Call litellm based on config type if is_image_gen_auto_mode(config_id): if not ImageGenRouterService.is_initialized(): - return { - "error": "No image generation models configured. " + err = ( + "No image generation models configured. " "Please add an image model in Settings > Image Models." - } + ) + return _failed({"error": err}, error=err) response = await ImageGenRouterService.aimage_generation( prompt=prompt, model="auto", **gen_kwargs ) elif config_id < 0: cfg = _get_global_image_gen_config(config_id) if not cfg: - return { - "error": f"Image generation config {config_id} not found" - } + err = f"Image generation config {config_id} not found" + return _failed({"error": err}, error=err) model_string = _build_model_string( cfg.get("provider", ""), @@ -151,9 +188,8 @@ def create_generate_image_tool( ) db_cfg = cfg_result.scalars().first() if not db_cfg: - return { - "error": f"Image generation config {config_id} not found" - } + err = f"Image generation config {config_id} not found" + return _failed({"error": err}, error=err) model_string = _build_model_string( db_cfg.provider.value, @@ -200,7 +236,10 @@ def create_generate_image_tool( # Extract image URLs from response images = response_dict.get("data", []) if not images: - return {"error": "No images were generated"} + return _failed( + {"error": 
"No images were generated"}, + error="No images were generated", + ) first_image = images[0] revised_prompt = first_image.get("revised_prompt", prompt) @@ -219,11 +258,14 @@ def create_generate_image_tool( f"{db_image_gen_id}/image?token={access_token}" ) else: - return {"error": "No displayable image data in the response"} + return _failed( + {"error": "No displayable image data in the response"}, + error="No displayable image data in the response", + ) image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}" - return { + payload = { "id": image_id, "assetId": image_url, "src": image_url, @@ -236,12 +278,26 @@ def create_generate_image_tool( "prompt": prompt, "image_count": len(images), } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="image", + operation="generate", + status="success", + external_id=str(db_image_gen_id), + verifiable_url=image_url, + preview=(revised_prompt or prompt)[:200], + ), + tool_call_id=runtime.tool_call_id, + ) except Exception as e: logger.exception("Image generation failed in tool") - return { - "error": f"Image generation failed: {e!s}", - "prompt": prompt, - } + err = f"Image generation failed: {e!s}" + return _failed( + {"error": err, "prompt": prompt}, + error=err, + ) return generate_image diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index 5f76f1d52..ddfcbd7fb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,5 +51,8 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], + image_generation_config_id_override=d.get( + "image_generation_config_id_override" + ), ), ] diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py index 55d9b3565..298257799 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py @@ -1,12 +1,28 @@ -"""Factory for a podcast-generation tool that queues background work and returns an ID for polling.""" +"""Factory for a podcast-generation tool. +Dispatches the heavy generation to Celery and then polls the podcast row +until it reaches a terminal status (READY/FAILED). The tool always +returns a real terminal ``Receipt`` — never a pending one. The wait is +bounded by the existing per-invocation safety net +(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode, +HTTP / process lifetime in single-agent mode). +""" + +import logging from typing import Any +from langchain.tools import ToolRuntime from langchain_core.tools import tool +from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.shared.deliverable_wait import wait_for_deliverable +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.db import Podcast, PodcastStatus, shielded_async_session +logger = logging.getLogger(__name__) + def create_generate_podcast_tool( search_space_id: int, @@ -19,9 +35,10 @@ def create_generate_podcast_tool( @tool async def generate_podcast( source_content: str, + runtime: ToolRuntime, podcast_title: str = "SurfSense Podcast", user_prompt: str | None = None, - ) -> dict[str, Any]: + ) -> Command: """ Generate a podcast from the provided content. 
@@ -70,23 +87,99 @@ def create_generate_podcast_tool( user_prompt=user_prompt, ) - print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}") + logger.info( + "[generate_podcast] Created podcast %s, task: %s", + podcast_id, + task.id, + ) - return { - "status": PodcastStatus.PENDING.value, + # Wait until the Celery worker flips the row to a terminal + # state. The wait is bounded only by the subagent invoke + # timeout (multi-agent) or HTTP lifetime (single-agent) — + # see app.agents.shared.deliverable_wait for details. + terminal_status, columns, elapsed = await wait_for_deliverable( + model=Podcast, + row_id=podcast_id, + columns=[Podcast.status, Podcast.file_location], + terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, + ) + + if terminal_status == PodcastStatus.READY: + file_location = columns[1] if columns else None + logger.info( + "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", + podcast_id, + elapsed, + file_location, + ) + payload: dict[str, Any] = { + "status": PodcastStatus.READY.value, + "podcast_id": podcast_id, + "title": podcast_title, + "file_location": file_location, + "message": ("Podcast generated and saved to your podcast panel."), + } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="podcast", + operation="generate", + status="success", + external_id=str(podcast_id), + preview=podcast_title, + ), + tool_call_id=runtime.tool_call_id, + ) + + # Only other terminal state is FAILED. + logger.warning( + "[generate_podcast] Podcast %s FAILED in %.2fs", + podcast_id, + elapsed, + ) + err = "Background worker reported FAILED status for this podcast." + payload = { + "status": PodcastStatus.FAILED.value, "podcast_id": podcast_id, "title": podcast_title, - "message": "Podcast generation started. 
This may take a few minutes.", + "error": err, } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="podcast", + operation="generate", + status="failed", + external_id=str(podcast_id), + preview=podcast_title, + error=err, + ), + tool_call_id=runtime.tool_call_id, + ) except Exception as e: error_message = str(e) - print(f"[generate_podcast] Error: {error_message}") - return { + logger.exception("[generate_podcast] Error: %s", error_message) + payload = { "status": PodcastStatus.FAILED.value, "error": error_message, "title": podcast_title, "podcast_id": None, } + receipt = make_receipt( + route="deliverables", + type="podcast", + operation="generate", + status="failed", + preview=podcast_title, + error=error_message, + ) + return with_receipt( + payload=payload, + receipt=receipt, + tool_call_id=runtime.tool_call_id, + ) return generate_podcast diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py index 385100c62..f12ca8a90 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/report.py @@ -6,10 +6,14 @@ import logging import re from typing import Any +from langchain.tools import ToolRuntime from langchain_core.callbacks import dispatch_custom_event from langchain_core.messages import HumanMessage from langchain_core.tools import tool +from langgraph.types import Command +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.db import Report, shielded_async_session from app.services.connector_service import ConnectorService from app.services.llm_service import get_document_summary_llm @@ -573,13 +577,14 @@ def create_generate_report_tool( @tool async def 
generate_report( topic: str, + runtime: ToolRuntime, source_content: str = "", source_strategy: str = "provided", search_queries: list[str] | None = None, report_style: str = "detailed", user_instructions: str | None = None, parent_report_id: int | None = None, - ) -> dict[str, Any]: + ) -> Command: """ Generate a structured Markdown report artifact from provided content. @@ -692,6 +697,23 @@ def create_generate_report_tool( parent_report_content: str | None = None report_group_id: int | None = None + def _failed(payload: dict[str, Any], *, error: str) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="report", + operation="generate", + status="failed", + external_id=str(payload.get("report_id")) + if payload.get("report_id") is not None + else None, + preview=topic, + error=error, + ), + tool_call_id=runtime.tool_call_id, + ) + async def _save_failed_report(error_msg: str) -> int | None: """Persist a failed report row using a short-lived session.""" try: @@ -753,12 +775,15 @@ def create_generate_report_tool( "No LLM configured. Please configure a language model in Settings." 
) report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": topic, + }, + error=error_msg, + ) # Build the user instructions string user_instructions_section = "" @@ -971,12 +996,15 @@ def create_generate_report_tool( if not report_content or not isinstance(report_content, str): error_msg = "LLM returned empty or invalid content" report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": topic, + }, + error=error_msg, + ) # LLMs often wrap output in ```markdown ... ``` fences — strip them report_content = _strip_wrapping_code_fences(report_content) @@ -984,12 +1012,15 @@ def create_generate_report_tool( if not report_content: error_msg = "LLM returned empty or invalid content" report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": topic, + }, + error=error_msg, + ) # Strip any existing footer(s) carried over from parent version(s) while report_content.rstrip().endswith(_REPORT_FOOTER): @@ -1036,7 +1067,7 @@ def create_generate_report_tool( f"{metadata.get('section_count', 0)} sections" ) - return { + payload: dict[str, Any] = { "status": "ready", "report_id": saved_report_id, "title": topic, @@ -1045,17 +1076,32 @@ def create_generate_report_tool( "report_markdown": report_content, "message": f"Report generated successfully: {topic}", } + receipt = make_receipt( + route="deliverables", + type="report", + operation="generate", + status="success", + external_id=str(saved_report_id), + 
preview=topic, + ) + return with_receipt( + payload=payload, + receipt=receipt, + tool_call_id=runtime.tool_call_id, + ) except Exception as e: error_message = str(e) logger.exception(f"[generate_report] Error: {error_message}") report_id = await _save_failed_report(error_message) - - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": topic, - } + return _failed( + { + "status": "failed", + "error": error_message, + "report_id": report_id, + "title": topic, + }, + error=error_message, + ) return generate_report diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py index ece3ce241..ad16b7ba7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/resume.py @@ -8,10 +8,14 @@ from typing import Any import pypdf import typst +from langchain.tools import ToolRuntime from langchain_core.callbacks import dispatch_custom_event from langchain_core.messages import HumanMessage from langchain_core.tools import tool +from langgraph.types import Command +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.db import Report, shielded_async_session from app.services.llm_service import get_document_summary_llm @@ -429,10 +433,11 @@ def create_generate_resume_tool( @tool async def generate_resume( user_info: str, + runtime: ToolRuntime, user_instructions: str | None = None, parent_report_id: int | None = None, max_pages: int = 1, - ) -> dict[str, Any]: + ) -> Command: """ Generate a professional resume as a Typst document. 
@@ -476,6 +481,41 @@ def create_generate_resume_tool( template = _get_template() llm_reference = _build_llm_reference(template) + def _success(payload: dict[str, Any], *, report_id: int, title: str) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="resume", + operation="generate", + status="success", + external_id=str(report_id), + preview=title, + ), + tool_call_id=runtime.tool_call_id, + ) + + def _failed( + payload: dict[str, Any], + *, + report_id: int | None, + error: str, + title: str = "Resume", + ) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="resume", + operation="generate", + status="failed", + external_id=str(report_id) if report_id is not None else None, + preview=title, + error=error, + ), + tool_call_id=runtime.tool_call_id, + ) + async def _save_failed_report(error_msg: str) -> int | None: try: async with shielded_async_session() as session: @@ -514,13 +554,17 @@ def create_generate_resume_tool( except ValueError as e: error_msg = str(e) report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) # ── Phase 1: READ ───────────────────────────────────────────── async with shielded_async_session() as read_session: @@ -541,13 +585,17 @@ def create_generate_resume_tool( "No LLM configured. Please configure a language model in Settings." 
) report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) # ── Phase 2: LLM GENERATION ─────────────────────────────────── @@ -588,13 +636,17 @@ def create_generate_resume_tool( if not body or not isinstance(body, str): error_msg = "LLM returned empty or invalid content" report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) body = _strip_typst_fences(body) body = _strip_imports(body) @@ -661,13 +713,17 @@ def create_generate_resume_tool( f"{compile_error or 'Unknown compile error'}" ) report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) actual_pages = _count_pdf_pages(pdf_bytes) if actual_pages <= validated_max_pages: @@ -700,13 +756,17 @@ def create_generate_resume_tool( ): error_msg = "LLM returned empty content while compressing resume" report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": 
"Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) body = _strip_typst_fences(compress_response.content) body = _strip_imports(body) @@ -718,13 +778,17 @@ def create_generate_resume_tool( f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}." ) report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_msg, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_msg, + ) # ── Phase 4: SAVE ───────────────────────────────────────────── dispatch_custom_event( @@ -768,32 +832,40 @@ def create_generate_resume_tool( logger.info(f"[generate_resume] Created resume {saved_id}: {resume_title}") - return { - "status": "ready", - "report_id": saved_id, - "title": resume_title, - "content_type": "typst", - "is_revision": bool(parent_content), - "message": ( - f"Resume generated successfully: {resume_title}" - if target_page_met - else ( - f"Resume generated, but could not fit the target of <= {validated_max_pages} " - f"page(s). Final length: {actual_pages} page(s)." - ) - ), - } + return _success( + { + "status": "ready", + "report_id": saved_id, + "title": resume_title, + "content_type": "typst", + "is_revision": bool(parent_content), + "message": ( + f"Resume generated successfully: {resume_title}" + if target_page_met + else ( + f"Resume generated, but could not fit the target of <= {validated_max_pages} " + f"page(s). Final length: {actual_pages} page(s)." 
+ ) + ), + }, + report_id=saved_id, + title=resume_title, + ) except Exception as e: error_message = str(e) logger.exception(f"[generate_resume] Error: {error_message}") report_id = await _save_failed_report(error_message) - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } + return _failed( + { + "status": "failed", + "error": error_message, + "report_id": report_id, + "title": "Resume", + "content_type": "typst", + }, + report_id=report_id, + error=error_message, + ) return generate_resume diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py index a9f3447ab..5407c8834 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/video_presentation.py @@ -1,12 +1,29 @@ -"""Factory for a video-presentation tool that queues background work and returns an ID for polling.""" +"""Factory for a video-presentation tool. +Dispatches the heavy generation to Celery and then polls the +video-presentation row until it reaches a terminal status (READY/FAILED). +The tool always returns a real terminal ``Receipt`` — never a pending +one. The wait is bounded by the existing per-invocation safety net +(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode, +HTTP / process lifetime in single-agent mode). Video rendering can be +heavy; raise that ceiling if your generations routinely exceed it. 
+""" + +import logging from typing import Any +from langchain.tools import ToolRuntime from langchain_core.tools import tool +from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.shared.deliverable_wait import wait_for_deliverable +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session +logger = logging.getLogger(__name__) + def create_generate_video_presentation_tool( search_space_id: int, @@ -19,9 +36,10 @@ def create_generate_video_presentation_tool( @tool async def generate_video_presentation( source_content: str, + runtime: ToolRuntime, video_title: str = "SurfSense Presentation", user_prompt: str | None = None, - ) -> dict[str, Any]: + ) -> Command: """Generate a video presentation from the provided content. Use this tool when the user asks to create a video, presentation, slides, or slide deck. @@ -56,25 +74,100 @@ def create_generate_video_presentation_tool( user_prompt=user_prompt, ) - print( - f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}" + logger.info( + "[generate_video_presentation] Created video presentation %s, task: %s", + video_pres_id, + task.id, ) - return { - "status": VideoPresentationStatus.PENDING.value, + # Wait until the Celery worker flips the row to a terminal + # state. The wait is bounded only by the subagent invoke + # timeout (multi-agent) or HTTP lifetime (single-agent) — + # see app.agents.shared.deliverable_wait for details. 
+ terminal_status, _columns, elapsed = await wait_for_deliverable( + model=VideoPresentation, + row_id=video_pres_id, + columns=[VideoPresentation.status], + terminal_statuses={ + VideoPresentationStatus.READY, + VideoPresentationStatus.FAILED, + }, + ) + + if terminal_status == VideoPresentationStatus.READY: + logger.info( + "[generate_video_presentation] %s READY in %.2fs", + video_pres_id, + elapsed, + ) + payload: dict[str, Any] = { + "status": VideoPresentationStatus.READY.value, + "video_presentation_id": video_pres_id, + "title": video_title, + "message": "Video presentation generated and saved.", + } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="video_presentation", + operation="generate", + status="success", + external_id=str(video_pres_id), + preview=video_title, + ), + tool_call_id=runtime.tool_call_id, + ) + + # Only other terminal state is FAILED. + logger.warning( + "[generate_video_presentation] %s FAILED in %.2fs", + video_pres_id, + elapsed, + ) + err = ( + "Background worker reported FAILED status for this video presentation." + ) + payload = { + "status": VideoPresentationStatus.FAILED.value, "video_presentation_id": video_pres_id, "title": video_title, - "message": "Video presentation generation started. 
This may take a few minutes.", + "error": err, } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="video_presentation", + operation="generate", + status="failed", + external_id=str(video_pres_id), + preview=video_title, + error=err, + ), + tool_call_id=runtime.tool_call_id, + ) except Exception as e: error_message = str(e) - print(f"[generate_video_presentation] Error: {error_message}") - return { + logger.exception("[generate_video_presentation] Error: %s", error_message) + payload = { "status": VideoPresentationStatus.FAILED.value, "error": error_message, "title": video_title, "video_presentation_id": None, } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="video_presentation", + operation="generate", + status="failed", + preview=video_title, + error=error_message, + ), + tool_call_id=runtime.tool_call_id, + ) return generate_video_presentation diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md index 2ae21c271..c4e36fc73 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_cloud.md @@ -150,11 +150,12 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: +<include snippet="output_contract_base"/> + +Route-specific rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. The supervisor already sees the tool's raw output. 
+<include snippet="verifiable_handle"/> + Infer before you call; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md index 4e5465aaf..25dafa3df 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/system_prompt_desktop.md @@ -33,7 +33,7 @@ Map outcomes to your `status`: - Any other `"Error: …"` → `status=error` and relay the tool's message verbatim as `next_step`. - HITL rejection → `status=blocked` with `next_step="User declined this filesystem action. Do not retry."`. -You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. `chunk_ids` apply only to `<priority_documents>` hits; for local-file operations leave them `null`. Never report values you did not actually see. +You construct the structured `evidence` fields from your own knowledge of what you called and what you observed — the tools do not return them. Never report values you did not actually see. (`chunk_ids` is always `null` in desktop mode — see "Chunk citations in your prose" below.) ## Chunk citations in your prose @@ -117,11 +117,12 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: +<include snippet="output_contract_base"/> + +Route-specific rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. - `evidence.content_excerpt`: max ~500 characters. Surface a short excerpt or a one-sentence summary, not the full file body. 
The supervisor already sees the tool's raw output. +<include snippet="verifiable_handle"/> + Infer before you call; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md index 13f7b68a5..b656c5019 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/system_prompt.md @@ -6,7 +6,7 @@ Persist durable preferences/facts/instructions with `update_memory` while avoidi </goal> <visibility_scope> -{{MEMORY_VISIBILITY_POLICY}} +Memory is search-space-scoped; do not assume cross-workspace visibility. </visibility_scope> <available_tools> @@ -53,10 +53,8 @@ Return **only** one JSON object (no markdown/prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` -> `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` -> `next_step` must be non-null. -- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null. +<include snippet="output_contract_base"/> +Route-specific rules: - `evidence.memory_category` is a semantic classification for supervisor logs only. It is not the persisted storage format and must not force inline `[fact|preference|instruction]` markers into saved memory. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md index f1a22ddf1..1b9ccaefa 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/system_prompt.md @@ -8,7 +8,6 @@ Gather and synthesize evidence using SurfSense research tools with clear citatio <available_tools> - `web_search` - `scrape_webpage` -- `search_surfsense_docs` </available_tools> <tool_policy> @@ -46,10 +45,8 @@ Return **only** one JSON object (no markdown/prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` -> `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` -> `next_step` must be non-null. -- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null. +<include snippet="output_contract_base"/> +Route-specific rules: - `evidence.findings`: max 10 entries, each a single sentence stating one distinct fact. Do not paste raw paragraphs, scraped pages, or quote blocks. - `evidence.sources`: max 10 URLs, one per finding when applicable. List each URL once. 
</output_contract> diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py index 414cc96f4..7234942b6 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/__init__.py @@ -1,11 +1,9 @@ -"""Research-stage tools: web search, scrape, and in-product doc search.""" +"""Research-stage tools: web search and scrape.""" from .scrape_webpage import create_scrape_webpage_tool -from .search_surfsense_docs import create_search_surfsense_docs_tool from .web_search import create_web_search_tool __all__ = [ "create_scrape_webpage_tool", - "create_search_surfsense_docs_tool", "create_web_search_tool", ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py index ea544a8da..d8abce46c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py @@ -9,7 +9,6 @@ from langchain_core.tools import BaseTool from app.agents.new_chat.permissions import Ruleset from .scrape_webpage import create_scrape_webpage_tool -from .search_surfsense_docs import create_search_surfsense_docs_tool from .web_search import create_web_search_tool NAME = "research" @@ -27,5 +26,4 @@ def load_tools( available_connectors=d.get("available_connectors"), ), create_scrape_webpage_tool(firecrawl_api_key=d.get("firecrawl_api_key")), - create_search_surfsense_docs_tool(db_session=d["db_session"]), ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/search_surfsense_docs.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/search_surfsense_docs.py deleted file mode 100644 index ccc5c49e2..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/search_surfsense_docs.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Semantic search over pre-indexed in-app documentation chunks for user how-to questions.""" - -import asyncio -import json - -from langchain_core.tools import tool -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument -from app.utils.document_converters import embed_text -from app.utils.surfsense_docs import surfsense_docs_public_url - - -def format_surfsense_docs_results(results: list[tuple]) -> str: - """Format (chunk, document) rows as XML with ``doc-`` chunk IDs for citations and UI routing.""" - if not results: - return "No relevant Surfsense documentation found for your query." - - # Group chunks by document - grouped: dict[int, dict] = {} - for chunk, doc in results: - public_url = surfsense_docs_public_url(doc.source) - if doc.id not in grouped: - grouped[doc.id] = { - "document_id": f"doc-{doc.id}", - "document_type": "SURFSENSE_DOCS", - "title": doc.title, - "url": public_url, - "metadata": {"source": doc.source, "public_url": public_url}, - "chunks": [], - } - grouped[doc.id]["chunks"].append( - { - "chunk_id": f"doc-{chunk.id}", - "content": chunk.content, - } - ) - - # Render XML matching format_documents_for_context structure - parts: list[str] = [] - for g in grouped.values(): - metadata_json = json.dumps(g["metadata"], ensure_ascii=False) - - parts.append("<document>") - parts.append("<document_metadata>") - parts.append(f" <document_id>{g['document_id']}</document_id>") - parts.append(f" <document_type>{g['document_type']}</document_type>") - parts.append(f" <title><![CDATA[{g['title']}]]>") - parts.append(f" ") - parts.append(f" ") - parts.append("") - 
parts.append("") - parts.append("") - - for ch in g["chunks"]: - parts.append( - f" " - ) - - parts.append("") - parts.append("") - parts.append("") - - return "\n".join(parts).strip() - - -async def search_surfsense_docs_async( - query: str, - db_session: AsyncSession, - top_k: int = 10, -) -> str: - """ - Search Surfsense documentation using vector similarity. - - Args: - query: The search query about Surfsense usage - db_session: Database session for executing queries - top_k: Number of results to return - - Returns: - Formatted string with relevant documentation content - """ - # Get embedding for the query - query_embedding = await asyncio.to_thread(embed_text, query) - - # Vector similarity search on chunks, joining with documents - stmt = ( - select(SurfsenseDocsChunk, SurfsenseDocsDocument) - .join( - SurfsenseDocsDocument, - SurfsenseDocsChunk.document_id == SurfsenseDocsDocument.id, - ) - .order_by(SurfsenseDocsChunk.embedding.op("<=>")(query_embedding)) - .limit(top_k) - ) - - result = await db_session.execute(stmt) - rows = result.all() - - return format_surfsense_docs_results(rows) - - -def create_search_surfsense_docs_tool(db_session: AsyncSession): - """ - Factory function to create the search_surfsense_docs tool. - - Args: - db_session: Database session for executing queries - - Returns: - A configured tool function for searching Surfsense documentation - """ - - @tool - async def search_surfsense_docs(query: str, top_k: int = 10) -> str: - """ - Search Surfsense documentation for help with using the application. - - Use this tool when the user asks questions about: - - How to use Surfsense features - - Installation and setup instructions - - Configuration options and settings - - Troubleshooting common issues - - Available connectors and integrations - - Browser extension usage - - API documentation - - This searches the official Surfsense documentation that was indexed - at deployment time. It does NOT search the user's personal knowledge base. 
- - Args: - query: The search query about Surfsense usage or features - top_k: Number of documentation chunks to retrieve (default: 10) - - Returns: - Relevant documentation content formatted with chunk IDs for citations - """ - return await search_surfsense_docs_async( - query=query, - db_session=db_session, - top_k=top_k, - ) - - return search_surfsense_docs diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md index 9434db7a1..e6a639af3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/system_prompt.md @@ -92,12 +92,12 @@ Return **only** one JSON object (no markdown, no prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + +Route-specific rules: - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: base, table, field, choice, record, etc.). - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (record id, primary-field value, and 1-2 most relevant fields; up to 10 entries, then `"...and N more"`). + + Discover before you mutate; never guess identifiers, choice IDs, or required fields. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md index a663f5b37..9168f4d2b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/system_prompt.md @@ -111,11 +111,12 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + +Route-specific rules: - For `search_calendar_events` results, set `evidence.items` to `{ "total": N }` and list the matched events in `action_summary` (title, date, start time; up to 10 entries, then `"...and N more"`). - For ambiguous matches across `update_calendar_event` / `delete_calendar_event`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`, where `label` should include the event title and start time for human readability). + + Infer before you call; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md index 898197f14..029609670 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/system_prompt.md @@ -93,12 +93,12 @@ Return **only** one JSON object (no markdown, no prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. 
-- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + +Route-specific rules: - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: task, list, member, status, custom-field choice, etc.). - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (task id, title, status, assignees; up to 10 entries, then `"...and N more"`). + + Discover before you mutate; never guess identifiers, list statuses, or assignees. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md index 991ec3d03..5aa687cd0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/system_prompt.md @@ -100,9 +100,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md index 249f9ec8b..aaabd2ac3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/system_prompt.md @@ -108,9 +108,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Resolve before you call; verify before you send; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md index a963b0ec6..8e498dfdf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/system_prompt.md @@ -98,9 +98,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md index c04d69ad0..02aff5589 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/system_prompt.md @@ -110,11 +110,12 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + +Route-specific rules: - For `search_gmail` results, set `evidence.items` to `{ "total": N }` and list the matched emails in `action_summary` (sender, subject, date; up to 10 entries, then `"...and N more"`). - For ambiguous matches across `update_gmail_draft` / `trash_gmail_email` / `read_gmail_email`, populate `evidence.matched_candidates` with up to 5 options (`id` + `label`). + + Infer before you call; verify before you send; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py index 578233b57..0680e51cb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py @@ -5,12 +5,16 @@ from datetime import datetime from email.mime.text import MIMEText from typing import Any +from langchain.tools import ToolRuntime from langchain_core.tools import tool +from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) @@ -26,9 +30,10 @@ def create_send_gmail_email_tool( to: str, subject: str, body: str, + runtime: ToolRuntime, cc: str | None = None, bcc: str | None = None, - ) -> dict[str, Any]: + ) -> Command: """Send an email via Gmail. Use when the user explicitly asks to send an email. 
This sends the @@ -60,11 +65,34 @@ def create_send_gmail_email_tool( """ logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'") + def _emit( + payload: dict[str, Any], + *, + success: bool, + external_id: str | None = None, + error: str | None = None, + ) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="gmail", + type="message", + operation="send", + status="success" if success else "failed", + external_id=external_id, + preview=f"to={to}: {subject}"[:200], + error=error, + ), + tool_call_id=runtime.tool_call_id, + ) + if db_session is None or search_space_id is None or user_id is None: - return { - "status": "error", - "message": "Gmail tool not properly configured. Please contact support.", - } + msg = "Gmail tool not properly configured. Please contact support." + return _emit( + {"status": "error", "message": msg}, + success=False, + error=msg, + ) try: metadata_service = GmailToolMetadataService(db_session) @@ -74,16 +102,24 @@ def create_send_gmail_email_tool( if "error" in context: logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} + return _emit( + {"status": "error", "message": context["error"]}, + success=False, + error=context["error"], + ) accounts = context.get("accounts", []) if accounts and all(a.get("auth_expired") for a in accounts): logger.warning("All Gmail accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "gmail", - } + return _emit( + { + "status": "auth_error", + "message": "All connected Gmail accounts need re-authentication. 
Please re-authenticate in your connector settings.", + "connector_type": "gmail", + }, + success=False, + error="auth_expired", + ) logger.info( f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'" @@ -103,10 +139,14 @@ def create_send_gmail_email_tool( ) if result.rejected: - return { - "status": "rejected", - "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", - } + return _emit( + { + "status": "rejected", + "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.", + }, + success=False, + error="user_rejected", + ) final_to = result.params.get("to", to) final_subject = result.params.get("subject", subject) @@ -135,10 +175,14 @@ def create_send_gmail_email_tool( ) connector = result.scalars().first() if not connector: - return { - "status": "error", - "message": "Selected Gmail connector is invalid or has been disconnected.", - } + msg = ( + "Selected Gmail connector is invalid or has been disconnected." + ) + return _emit( + {"status": "error", "message": msg}, + success=False, + error=msg, + ) actual_connector_id = connector.id else: result = await db_session.execute( @@ -150,10 +194,12 @@ def create_send_gmail_email_tool( ) connector = result.scalars().first() if not connector: - return { - "status": "error", - "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", - } + msg = "No Gmail connector found. Please connect Gmail in your workspace settings." + return _emit( + {"status": "error", "message": msg}, + success=False, + error=msg, + ) actual_connector_id = connector.id logger.info( @@ -166,10 +212,12 @@ def create_send_gmail_email_tool( ): cca_id = connector.config.get("composio_connected_account_id") if not cca_id: - return { - "status": "error", - "message": "Composio connected account ID not found for this Gmail connector.", - } + msg = "Composio connected account ID not found for this Gmail connector." 
+ return _emit( + {"status": "error", "message": msg}, + success=False, + error=msg, + ) from app.services.composio_service import ComposioService @@ -187,7 +235,11 @@ def create_send_gmail_email_tool( bcc=final_bcc, ) if error: - return {"status": "error", "message": error} + return _emit( + {"status": "error", "message": error}, + success=False, + error=error, + ) sent = {"id": sent_message_id, "threadId": sent_thread_id} else: from google.oauth2.credentials import Credentials @@ -275,11 +327,15 @@ def create_send_gmail_email_tool( actual_connector_id, exc_info=True, ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } + return _emit( + { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + }, + success=False, + error="insufficient_permissions", + ) raise logger.info( @@ -310,12 +366,16 @@ def create_send_gmail_email_tool( logger.warning(f"KB sync after send failed: {kb_err}") kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync." 
- return { - "status": "success", - "message_id": sent.get("id"), - "thread_id": sent.get("threadId"), - "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", - } + return _emit( + { + "status": "success", + "message_id": sent.get("id"), + "thread_id": sent.get("threadId"), + "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}", + }, + success=True, + external_id=sent.get("id"), + ) except Exception as e: from langgraph.errors import GraphInterrupt @@ -324,9 +384,11 @@ def create_send_gmail_email_tool( raise logger.error(f"Error sending Gmail email: {e}", exc_info=True) - return { - "status": "error", - "message": "Something went wrong while sending the email. Please try again.", - } + msg = "Something went wrong while sending the email. Please try again." + return _emit( + {"status": "error", "message": msg}, + success=False, + error=str(e), + ) return send_gmail_email diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md index b78e1f7c6..10140d842 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/system_prompt.md @@ -100,9 +100,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md index 4dcc56454..d7816dead 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/system_prompt.md @@ -111,12 +111,12 @@ Return **only** one JSON object (no markdown, no prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + +Route-specific rules: - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: site, project, issue, user, transition, etc.). - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (issue key, summary, status, assignee; up to 10 entries, then `"...and N more"`). + + Discover before you mutate; never guess identifiers, transitions, or required fields. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md index 1d96a4105..5dfd29112 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/system_prompt.md @@ -101,12 +101,12 @@ Return **only** one JSON object (no markdown, no prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. 
-- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + +Route-specific rules: - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: issue, user, project, state, etc.). - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (identifier, title, state, assignee; up to 10 entries, then `"...and N more"`). + + Discover before you mutate; never guess identifiers. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md index 0f42161b3..e483789d5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/system_prompt.md @@ -101,9 +101,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; verify before you create; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md index b38c30167..909c72471 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/system_prompt.md @@ -99,9 +99,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py index 85d0ef22e..c98b25811 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py @@ -1,12 +1,16 @@ import logging from typing import Any +from langchain.tools import ToolRuntime from langchain_core.tools import tool +from langgraph.types import Command from sqlalchemy.ext.asyncio import AsyncSession from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) +from app.agents.shared.receipt import make_receipt +from app.agents.shared.receipt_command import with_receipt from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion.tool_metadata_service import NotionToolMetadataService @@ -35,8 +39,9 @@ def create_delete_notion_page_tool( @tool async def delete_notion_page( page_title: str, + runtime: ToolRuntime, delete_from_kb: bool = 
False, - ) -> dict[str, Any]: + ) -> Command: """Delete (archive) a Notion page. Use this tool when the user asks you to delete, remove, or archive @@ -65,14 +70,39 @@ def create_delete_notion_page_tool( f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}" ) + def _emit( + payload: dict[str, Any], + *, + status: str, + external_id: str | None = None, + error: str | None = None, + ) -> Command: + return with_receipt( + payload=payload, + receipt=make_receipt( + route="notion", + type="page", + operation="delete", + status="success" if status == "success" else "failed", + external_id=external_id, + preview=page_title, + error=error, + ), + tool_call_id=runtime.tool_call_id, + ) + if db_session is None or search_space_id is None or user_id is None: logger.error( "Notion tool not properly configured - missing required parameters" ) - return { - "status": "error", - "message": "Notion tool not properly configured. Please contact support.", - } + return _emit( + { + "status": "error", + "message": "Notion tool not properly configured. Please contact support.", + }, + status="error", + error="Notion tool not properly configured. 
Please contact support.", + ) try: # Get page context (page_id, account, title) from indexed data @@ -86,16 +116,18 @@ def create_delete_notion_page_tool( # Check if it's a "not found" error (softer handling for LLM) if "not found" in error_msg.lower(): logger.warning(f"Page not found: {error_msg}") - return { - "status": "not_found", - "message": error_msg, - } + return _emit( + {"status": "not_found", "message": error_msg}, + status="error", + error=error_msg, + ) else: logger.error(f"Failed to fetch delete context: {error_msg}") - return { - "status": "error", - "message": error_msg, - } + return _emit( + {"status": "error", "message": error_msg}, + status="error", + error=error_msg, + ) account = context.get("account", {}) if account.get("auth_expired"): @@ -103,10 +135,14 @@ def create_delete_notion_page_tool( "Notion account %s has expired authentication", account.get("id"), ) - return { - "status": "auth_error", - "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", - } + return _emit( + { + "status": "auth_error", + "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.", + }, + status="error", + error="auth_expired", + ) page_id = context.get("page_id") connector_id_from_context = account.get("id") @@ -129,10 +165,14 @@ def create_delete_notion_page_tool( if result.rejected: logger.info("Notion page deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } + return _emit( + { + "status": "rejected", + "message": "User declined. 
Do not retry or suggest alternatives.", + }, + status="error", + error="user_rejected", + ) final_page_id = result.params.get("page_id", page_id) final_connector_id = result.params.get( @@ -165,18 +205,26 @@ def create_delete_notion_page_tool( logger.error( f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" ) - return { - "status": "error", - "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", - } + return _emit( + { + "status": "error", + "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.", + }, + status="error", + error="invalid_connector", + ) actual_connector_id = connector.id logger.info(f"Validated Notion connector: id={actual_connector_id}") else: logger.error("No connector found for this page") - return { - "status": "error", - "message": "No connector found for this page.", - } + return _emit( + { + "status": "error", + "message": "No connector found for this page.", + }, + status="error", + error="no_connector", + ) # Create connector instance notion_connector = NotionHistoryConnector( @@ -232,7 +280,13 @@ def create_delete_notion_page_tool( f"{result.get('message', '')} (also removed from knowledge base)" ) - return result + status = result.get("status", "error") + return _emit( + result, + status=status, + external_id=str(final_page_id) if final_page_id else None, + error=None if status == "success" else result.get("message"), + ) except Exception as e: from langgraph.errors import GraphInterrupt @@ -245,20 +299,28 @@ def create_delete_notion_page_tool( if isinstance(e, NotionAPIError) and ( "401" in error_str or "unauthorized" in error_str ): - return { - "status": "auth_error", - "message": str(e), - "connector_id": connector_id_from_context - if "connector_id_from_context" in dir() - else None, - "connector_type": "notion", - } + return _emit( + { + "status": "auth_error", + "message": str(e), + "connector_id": 
connector_id_from_context + if "connector_id_from_context" in dir() + else None, + "connector_type": "notion", + }, + status="error", + error=str(e), + ) if isinstance(e, ValueError | NotionAPIError): message = str(e) else: message = ( "Something went wrong while deleting the page. Please try again." ) - return {"status": "error", "message": message} + return _emit( + {"status": "error", "message": message}, + status="error", + error=message, + ) return delete_notion_page diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md index 8ae444a58..4b45b05a9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/system_prompt.md @@ -97,9 +97,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Infer before you call; map every tool outcome faithfully. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md index 3c24b19c9..e4e0d1f6f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/system_prompt.md @@ -87,12 +87,12 @@ Return **only** one JSON object (no markdown, no prose): "missing_fields": string[] | null, "assumptions": string[] | null } -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. 
-- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + +Route-specific rules: - For blocked ambiguity, populate `evidence.matched_candidates` with up to 5 options (`id` + `label` — works for any kind of candidate: channel, user, message, thread). - For discovery-only queries (lists), set `evidence.items` to `{ "total": N }` and list the matched items in `action_summary` (channel/user, key identifier, timestamp, short snippet; up to 10 entries, then `"...and N more"`). + + Discover before you post; never guess channel, user, or thread targets. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md index c3a280f79..9b283acf5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/system_prompt.md @@ -115,9 +115,8 @@ Return **only** one JSON object (no markdown or prose outside it): } ``` -Rules: -- `status=success` → `next_step=null`, `missing_fields=null`. -- `status=partial|blocked|error` → `next_step` must be non-null. -- `status=blocked` due to missing required inputs → `missing_fields` must be non-null. + + + Resolve before you call; verify before you send; map every tool outcome faithfully. 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py index 8729ea85b..2f7e3cd35 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py @@ -49,6 +49,7 @@ def request_approval( params: dict[str, Any], context: dict[str, Any] | None = None, trusted_tools: list[str] | None = None, + tool_call_id: str | None = None, ) -> HITLResult: """Pause the graph for user approval and return the user's decision. @@ -64,6 +65,10 @@ def request_approval( forwarded verbatim to the FE for richer card chrome. trusted_tools: Per-session allowlist; when ``tool_name`` is in it the interrupt is skipped and the tool runs immediately. + tool_call_id: Caller's LangChain tool-call id. Required for tools + running directly on the main agent; subagent-mounted tools omit + it (the ``task`` chokepoint stamps it on re-raise — see + :mod:`...checkpointed_subagent_middleware.propagation`). 
Returns: :class:`HITLResult` with ``rejected=True`` if the user declined or @@ -90,6 +95,8 @@ def request_approval( interrupt_type=action_type, context=context, ) + if tool_call_id: + payload["tool_call_id"] = tool_call_id approval = interrupt(payload) parsed = parse_lc_envelope(approval) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py index 2fce413a6..5694e4326 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/md_file_reader.py @@ -2,8 +2,11 @@ from __future__ import annotations +from functools import lru_cache from importlib import resources +_SHARED_SNIPPETS_PACKAGE = "app.agents.multi_agent_chat.subagents.shared.snippets" + def read_md_file(package: str, stem: str) -> str: """Load ``{stem}.md`` from ``package`` via importlib resources, or return empty.""" @@ -12,3 +15,13 @@ def read_md_file(package: str, stem: str) -> str: return "" text = ref.read_text(encoding="utf-8") return text.rstrip("\n") + + +@lru_cache(maxsize=64) +def read_shared_snippet(name: str) -> str: + """Load a shared markdown snippet from the snippets package. + + Cached because snippets are static at runtime and resolved many times + (once per subagent build, plus per-subagent-per-route). + """ + return read_md_file(_SHARED_SNIPPETS_PACKAGE, name) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/__init__.py new file mode 100644 index 000000000..802a8e241 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/__init__.py @@ -0,0 +1,6 @@ +"""Shared markdown snippets composed into every subagent system prompt. 
+ +Resolved at build time by :func:`pack_subagent` in ``subagent_builder.py`` +via the ```` directive. See ``output_contract_base.md`` +and ``verifiable_handle.md`` for the included content. +""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/output_contract_base.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/output_contract_base.md new file mode 100644 index 000000000..100daae75 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/output_contract_base.md @@ -0,0 +1,6 @@ +Rules (universal): +- `status=success` -> `next_step=null`, `missing_fields=null`. +- `status=partial|blocked|error` -> `next_step` must be non-null. +- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null. +- `assumptions`: any inferences you made about the user's intent; `null` when no inferences were needed. +- The `evidence` object's fields are documented in your route-specific `` above; never invent fields the tool did not return. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md new file mode 100644 index 000000000..bea070ce9 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/snippets/verifiable_handle.md @@ -0,0 +1,10 @@ + +Mutating tools you call return a structured `Receipt` object alongside their normal payload (see `evidence.receipts` in your ``). The supervisor uses the Receipt's `verifiable_url` and `external_id` to independently confirm the operation succeeded - do not paraphrase, shorten, or guess these values. + +Rules: +- Quote each Receipt's `verifiable_url` and `external_id` **verbatim** in `evidence.receipts`. Copy character-for-character; never retype from memory. 
+- If a Receipt has `status="failed"`, set your own `status="error"` and put the Receipt's `error` field in `next_step`. +- If a Receipt has `status="pending"` (async backends — podcasts, video presentations, anything queued through Celery), report `status=success`, surface the pending Receipt as-is, and tell the supervisor in `action_summary` that the artefact is **being generated in the background** (e.g. "Podcast 38 queued; orchestrator should report it as kicked off, not yet ready"). A pending Receipt almost always lacks `verifiable_url` because the artefact does not exist yet — that is expected, not a defect. Do **not** wait, poll, or retry; control returns to the supervisor immediately and the asset becomes visible to the user out of band via its own UI surface. +- Never claim a mutation succeeded without a matching Receipt with `status="success"` or `"pending"` in your tool results this turn. +- For tools that do not return a Receipt (read-only operations, search, lookup), the receipt rules do not apply; only the route-specific `evidence` fields matter. + diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py index 797ab535b..f891f94d2 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py @@ -2,12 +2,30 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from dataclasses import dataclass +from typing import Any from deepagents import SubAgent from app.agents.new_chat.permissions import Ruleset +# A context-hint provider receives the parent-agent ``runtime.state`` mapping +# and the ``description`` the orchestrator wrote, and returns a short string +# the runtime prepends to the subagent's first ``HumanMessage``. 
Used for +# things like "current search-space id is X" or "the user is in workspace Y" — +# never for full corpora, since the prepended text consumes the subagent's +# prompt budget on every invocation. Return ``None`` (or an empty string) to +# skip the hint for this call. +ContextHintProvider = Callable[[Mapping[str, Any], str], str | None] + +# Custom key stashed on the deepagents ``SubAgent`` dict so the provider +# survives the trip from ``pack_subagent`` → registry → middleware → +# task_tool. ``deepagents.create_agent`` only extracts the keys it +# recognises, so an extra key here is dropped silently at compile time. +# The prefix avoids any collision with future deepagents fields. +SURF_CONTEXT_HINT_PROVIDER_KEY = "surf_context_hint_provider" + @dataclass(frozen=True, slots=True) class SurfSenseSubagentSpec: @@ -20,10 +38,22 @@ class SurfSenseSubagentSpec: layers them into the subagent's :class:`PermissionMiddleware`, so each subagent owns its own ruleset without aliasing the shared rule engine. + context_hint_provider: Optional callback invoked once per ``task(...)`` + invocation, immediately before the subagent runs. Its return + value is prepended to the subagent's first ``HumanMessage`` so + the subagent can see things it would otherwise have to discover + (active search space, KB root, current user timezone, etc.). + Kept out of the deepagents ``spec`` because that dict is forwarded + verbatim to upstream code and only recognises its own typed keys. 
""" spec: SubAgent ruleset: Ruleset + context_hint_provider: ContextHintProvider | None = None -__all__ = ["SurfSenseSubagentSpec"] +__all__ = [ + "SURF_CONTEXT_HINT_PROVIDER_KEY", + "ContextHintProvider", + "SurfSenseSubagentSpec", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py index 7173901f9..5025b32e7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py @@ -2,6 +2,8 @@ from __future__ import annotations +import logging +import re from typing import Any, cast from deepagents import SubAgent @@ -12,9 +14,48 @@ from langchain_core.tools import BaseTool from app.agents.multi_agent_chat.middleware.shared.permissions import ( build_permission_mw, ) -from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( + read_shared_snippet, +) +from app.agents.multi_agent_chat.subagents.shared.spec import ( + SURF_CONTEXT_HINT_PROVIDER_KEY, + ContextHintProvider, + SurfSenseSubagentSpec, +) from app.agents.new_chat.permissions import Ruleset +logger = logging.getLogger(__name__) + +# ```` directive. Matches an XML-style self-closing +# tag whose ``snippet`` attribute names a file in ``shared/snippets/``. +# Whitespace around the attribute and self-close is tolerated; the snippet +# name itself must be a bare identifier (letters / digits / underscores) so +# we never pull a path-traversal value into ``read_shared_snippet``. +_INCLUDE_DIRECTIVE_RE = re.compile( + r"[A-Za-z0-9_]+)\"\s*/>" +) + + +def _resolve_includes(prompt: str, *, subagent_name: str) -> str: + """Replace ```` directives with the snippet body. 
+ + Unknown snippet names raise; an empty body is treated as unknown so a + typo or missing file fails loudly at startup instead of silently + shipping a broken prompt to the LLM. + """ + + def _replace(match: re.Match[str]) -> str: + name = match.group("name") + body = read_shared_snippet(name) + if not body.strip(): + raise ValueError( + f"Subagent {subagent_name!r}: unknown or empty shared " + f"snippet {name!r} referenced via ." + ) + return body + + return _INCLUDE_DIRECTIVE_RE.sub(_replace, prompt) + def _user_allowlist_for( dependencies: dict[str, Any], subagent_name: str @@ -43,6 +84,7 @@ def pack_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, + context_hint_provider: ContextHintProvider | None = None, ) -> SurfSenseSubagentSpec: """Pack the route-local pieces into one sub-agent spec + its Ruleset. @@ -68,6 +110,8 @@ def pack_subagent( msg = f"Subagent {name!r}: system_prompt is empty" raise ValueError(msg) + system_prompt = _resolve_includes(system_prompt, subagent_name=name) + flags = dependencies["flags"] user_allowlist = _user_allowlist_for(dependencies, name) subagent_rulesets: list[Ruleset] = [ruleset] @@ -99,4 +143,12 @@ def pack_subagent( } if model is not None: spec_dict["model"] = model - return SurfSenseSubagentSpec(spec=cast(SubAgent, spec_dict), ruleset=ruleset) + if context_hint_provider is not None: + # Stash the callback on the dict so it survives the trip through + # registry / middleware unpacking (both treat the spec as opaque). 
+ spec_dict[SURF_CONTEXT_HINT_PROVIDER_KEY] = context_hint_provider + return SurfSenseSubagentSpec( + spec=cast(SubAgent, spec_dict), + ruleset=ruleset, + context_hint_provider=context_hint_provider, + ) diff --git a/surfsense_backend/app/agents/new_chat/anonymous_agent.py b/surfsense_backend/app/agents/new_chat/anonymous_agent.py new file mode 100644 index 000000000..c783d9a45 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/anonymous_agent.py @@ -0,0 +1,168 @@ +"""Minimal anonymous / free-chat agent. + +The no-login chat experience must stay dead simple: the user asks a question +and the model answers, optionally using ``web_search`` and an optionally +uploaded **read-only** document. We deliberately bypass the full SurfSense deep +agent stack (filesystem, file-intent, knowledge-base persistence, subagents, +skills, memory) because those middlewares stage or persist "documents" that an +anonymous session can never see again -- which produced phantom +"I saved it to a file" answers for free users. + +For any other SurfSense capability the model is instructed (via the system +prompt built here) to tell the user to create a free account instead of +pretending to perform the action. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from deepagents.backends import StateBackend +from langchain.agents import create_agent +from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + ToolCallLimitMiddleware, +) +from langchain_core.language_models import BaseChatModel +from langgraph.types import Checkpointer + +from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.middleware import ( + RetryAfterMiddleware, + create_surfsense_compaction_middleware, +) +from app.agents.new_chat.tools.web_search import create_web_search_tool + +# Cap how much of an uploaded document we inline into the system prompt. 
The +# upload endpoint allows files up to several MB, but the doc is re-sent on +# every turn and counts against the anonymous token quota, so we bound it. +_MAX_DOC_CHARS = 50_000 + + +def build_anonymous_system_prompt(anon_doc: dict[str, Any] | None = None) -> str: + """Build the system prompt for the minimal anonymous chat agent. + + The prompt keeps the assistant focused on plain Q/A + web search, inlines + any uploaded document as read-only context, and redirects every other + SurfSense feature to account registration. + """ + today = datetime.now(UTC).strftime("%A, %B %d, %Y") + + doc_section = "" + if anon_doc: + title = str(anon_doc.get("title") or "uploaded_document") + content = str(anon_doc.get("content") or "") + truncated = content[:_MAX_DOC_CHARS] + truncation_note = "" + if len(content) > _MAX_DOC_CHARS: + truncation_note = ( + "\n\n[Note: the document was truncated because it is large; " + "only the beginning is shown.]" + ) + doc_section = ( + "\n\n## Uploaded document (read-only)\n" + f'The user uploaded a document named "{title}". Its contents are ' + "provided below for reference only. You may read it and answer " + "questions about it, but you cannot modify, save, or store it.\n\n" + f'\n' + f"{truncated}{truncation_note}\n" + "" + ) + + return ( + "You are SurfSense's free AI assistant, available to everyone without " + "login.\n\n" + f"Today's date is {today}.\n\n" + "## How to help\n" + "- Answer the user's questions directly and conversationally. You are " + "a straightforward question-and-answer assistant.\n" + "- When a question needs current, real-time, or factual information " + "from the internet (news, prices, weather, recent events, live data), " + "use the `web_search` tool. Otherwise, answer directly from your own " + "knowledge.\n" + "- Be concise, accurate, and helpful. Use Markdown formatting when it " + "improves readability." + f"{doc_section}\n\n" + "## What is not available here\n" + "This is the free, no-login experience. 
You CANNOT save files or " + "notes, generate reports, podcasts, resumes, presentations, or images, " + "search or build a knowledge base, connect to apps (Gmail, Google " + "Drive, Notion, Slack, Calendar, Discord, and similar), set up " + "automations, or remember anything across sessions.\n\n" + "If the user asks for any of these, do NOT pretend to do them and " + "never claim you saved, created, or stored anything. Instead, briefly " + "let them know the feature requires a free SurfSense account and " + "invite them to create one at https://www.surfsense.com. Then offer to " + "help with what you can do here (answering questions and searching the " + "web)." + ) + + +async def create_anonymous_chat_agent( + *, + llm: BaseChatModel, + checkpointer: Checkpointer, + anon_session_id: str | None = None, + anon_doc: dict[str, Any] | None = None, + enable_web_search: bool = True, +): + """Create a minimal Q/A agent for anonymous / free chat. + + Unlike :func:`create_surfsense_deep_agent`, this agent has no filesystem, + file-intent, knowledge-base persistence, subagent, skills, or memory + middleware. Its only tool is ``web_search`` (when ``enable_web_search`` is + True), and any uploaded document is injected into the system prompt as + read-only context. + + Args: + llm: The chat model to use (already built by the caller). + checkpointer: LangGraph checkpointer for the ephemeral anon thread. + anon_session_id: Anonymous session id (used only for telemetry/metadata). + anon_doc: Optional ``{"title", "content"}`` for an uploaded document. + enable_web_search: When False, the agent runs as a pure LLM with no + tools (used when the user toggles web search off). + """ + tools = ( + [create_web_search_tool(search_space_id=None, available_connectors=None)] + if enable_web_search + else [] + ) + + # Reliability-only middleware. 
Nothing here touches the database or + # filesystem: call limits guard against loops, compaction summarises long + # histories into in-graph state, and retry handles provider rate limits. + middleware: list[Any] = [ + ModelCallLimitMiddleware(thread_limit=120, run_limit=80, exit_behavior="end"), + ] + if tools: + middleware.append( + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ) + ) + middleware.append(create_surfsense_compaction_middleware(llm, StateBackend)) + middleware.append(RetryAfterMiddleware(max_retries=3)) + + system_prompt = build_anonymous_system_prompt(anon_doc) + + agent = create_agent( + llm, + system_prompt=system_prompt, + tools=tools, + middleware=middleware, + context_schema=SurfSenseContextSchema, + checkpointer=checkpointer, + ) + return agent.with_config( + { + "recursion_limit": 40, + "metadata": { + "ls_integration": "surfsense_anonymous_chat", + "anon_session_id": anon_session_id, + }, + } + ) + + +__all__ = ["build_anonymous_system_prompt", "create_anonymous_chat_agent"] diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index a20a43a66..1b3ea3d20 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -64,6 +64,8 @@ class SurfSenseContextSchema: search_space_id: int | None = None mentioned_document_ids: list[int] = field(default_factory=list) mentioned_folder_ids: list[int] = field(default_factory=list) + mentioned_connector_ids: list[int] = field(default_factory=list) + mentioned_connectors: list[dict[str, object]] = field(default_factory=list) file_operation_contract: FileOperationContractState | None = None turn_id: str | None = None request_id: str | None = None diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 3cea051ef..27188fac3 100644 --- 
a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -104,7 +104,7 @@ class AgentFeatureFlags: # ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``, # ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``, # ``tools/teams``, ``tools/luma``, ``connected_accounts``, - # ``update_memory``, ``search_surfsense_docs``) now acquire fresh + # ``update_memory``) now acquire fresh # short-lived ``AsyncSession`` instances per call via # :data:`async_session_maker`. The factory still accepts ``db_session`` # for registry compatibility but ``del``'s it immediately — see any diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/new_chat/filesystem_state.py index cc674be76..de2c94b41 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_state.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_state.py @@ -33,9 +33,11 @@ from typing_extensions import TypedDict from app.agents.new_chat.state_reducers import ( _add_unique_reducer, _dict_merge_with_tombstones_reducer, + _int_counter_merge_reducer, _list_append_reducer, _replace_reducer, ) +from app.agents.shared.receipt import Receipt class PendingMove(TypedDict, total=False): @@ -172,6 +174,35 @@ class SurfSenseFilesystemState(FilesystemState): workspace_tree_text: NotRequired[Annotated[str, _replace_reducer]] """Pre-rendered ```` body; shared with subagents to skip re-render.""" + billable_calls: NotRequired[Annotated[dict[str, int], _int_counter_merge_reducer]] + """Per-subagent ``task(...)`` invocation counter, summed across the turn. + + Incremented by ``task_tool.py`` each time a subagent invocation + completes (single- or batch-mode). The orchestrator can read this map + to self-limit when a runaway loop sends the same specialist 20 calls + in a row; the runtime emits a soft warning ToolMessage once the + cumulative count crosses :data:`DEFAULT_SUBAGENT_BILLABLE_THRESHOLD`. 
+ Cleared by checkpoint rollover (i.e. per turn). + """ + + receipts: NotRequired[Annotated[list[Receipt], _list_append_reducer]] + """Structured Receipt handles emitted by mutating subagent tools this turn. + + Each mutating tool (deliverables, every connector, KB writes via the + persistence middleware) wraps its native return into a + :class:`~app.agents.shared.receipt.Receipt` + and returns it under the ``"receipt"`` key alongside its existing + payload. The subagent's tool-call middleware folds the receipt into + this list, and ``_return_command_with_state_update`` in + ``checkpointed_subagent_middleware/task_tool.py`` carries the list up + to the parent automatically (``"receipts"`` is not in + ``EXCLUDED_STATE_KEYS``). + + Append-only across the turn; cleared by checkpoint rollover. The + orchestrator reads it via the ```` teaching to confirm + side-effecting subagent claims (see ``shared/snippets/verifiable_handle.md``). + """ + __all__ = [ "KbAnonDoc", diff --git a/surfsense_backend/app/agents/new_chat/mention_resolver.py b/surfsense_backend/app/agents/new_chat/mention_resolver.py index 00bb7e71f..f13dbc6ae 100644 --- a/surfsense_backend/app/agents/new_chat/mention_resolver.py +++ b/surfsense_backend/app/agents/new_chat/mention_resolver.py @@ -73,9 +73,8 @@ class ResolvedMentionSet: ``@Project Roadmap`` is never shadowed by a shorter prefix ``@Project``). - ``mentioned_document_ids`` collapses doc + surfsense_doc chips into - a single ordered, deduped list because the priority middleware - treats them uniformly downstream — see + ``mentioned_document_ids`` is an ordered, deduped list consumed by + the priority middleware downstream — see ``KnowledgePriorityMiddleware._compute_priority_paths``. 
""" @@ -103,7 +102,6 @@ async def resolve_mentions( search_space_id: int, mentioned_documents: list[MentionedDocumentInfo] | None, mentioned_document_ids: list[int] | None = None, - mentioned_surfsense_doc_ids: list[int] | None = None, mentioned_folder_ids: list[int] | None = None, ) -> ResolvedMentionSet: """Resolve every @-mention chip on a turn into virtual paths. @@ -111,8 +109,7 @@ async def resolve_mentions( The function takes both the ``mentioned_documents`` discriminated list (chip metadata used for substitution + persistence) and the parallel id arrays (``mentioned_document_ids``, - ``mentioned_surfsense_doc_ids``, ``mentioned_folder_ids``) for two - reasons: + ``mentioned_folder_ids``) for two reasons: * Legacy clients that haven't migrated to the unified chip list still send the id arrays — we treat the union as authoritative. @@ -134,7 +131,7 @@ async def resolve_mentions( kind = chip.kind if kind == "folder": chip_folder_ids.append(chip.id) - else: + elif kind == "doc": chip_doc_ids.append(chip.id) chip_titles_by_id[(kind, chip.id)] = chip.title @@ -142,7 +139,6 @@ async def resolve_mentions( dict.fromkeys( [ *(mentioned_document_ids or []), - *(mentioned_surfsense_doc_ids or []), *chip_doc_ids, ] ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/new_chat/middleware/compaction.py index 16361e16b..f8d340e5d 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/compaction.py +++ b/surfsense_backend/app/agents/new_chat/middleware/compaction.py @@ -34,7 +34,7 @@ from deepagents.middleware.summarization import ( ) from langchain_core.messages import SystemMessage -from app.observability import otel as ot +from app.observability import metrics as ot_metrics, otel as ot if TYPE_CHECKING: from deepagents.backends.protocol import BACKEND_TYPES @@ -178,6 +178,7 @@ class SurfSenseCompactionMiddleware(SummarizationMiddleware): messages_in=len(conversation_messages), 
extra={"compaction.cutoff_index": int(cutoff_index)}, ): + ot_metrics.record_compaction_run(reason="auto") messages_to_summarize, preserved_messages = super()._partition_messages( conversation_messages, cutoff_index ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py index 850ecd1d2..a7901c010 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py +++ b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py @@ -47,7 +47,7 @@ from langgraph.config import get_config from langgraph.runtime import Runtime from langgraph.types import interrupt -from app.observability import otel as ot +from app.observability import metrics as ot_metrics, otel as ot logger = logging.getLogger(__name__) @@ -195,6 +195,7 @@ class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respon "interrupt.tool": (action or {}).get("tool", ""), }, ): + ot_metrics.record_interrupt(interrupt_type="permission_ask") decision = interrupt( { "type": "permission_ask", diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py index cc30f4897..c88dced85 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -55,6 +55,7 @@ from app.agents.new_chat.path_resolver import ( virtual_path_to_doc, ) from app.agents.new_chat.state_reducers import _CLEAR +from app.agents.shared.receipt import Receipt, make_receipt from app.db import ( AgentActionLog, Chunk, @@ -1392,6 +1393,81 @@ async def commit_staged_filesystem_state( "pending_dir_deletes": [_CLEAR], "dirty_path_tool_calls": {_CLEAR: True}, } + + # Emit one Receipt per committed mutation, folded into ``state['receipts']`` + # via ``_list_append_reducer``. 
The receipts surface what actually committed + # (post-savepoint) rather than what the LLM intended; the orchestrator uses + # them as ground truth in the ```` teaching. KB writes do not + # have public verifiable URLs, so ``verifiable_url`` stays unset. + receipts: list[Receipt] = [] + + def _kb_receipt( + *, + type: str, + operation: str, + path: str, + external_id: int | None = None, + ) -> None: + if not path: + return + preview = path.rsplit("/", 1)[-1] or path + receipts.append( + make_receipt( + route="knowledge_base", + type=type, + operation=operation, + status="success", + external_id=str(external_id) if external_id is not None else path, + preview=preview, + ) + ) + + for payload in committed_creates: + path = str(payload.get("virtualPath") or "") + _kb_receipt( + type="file", + operation="write_file", + path=path, + external_id=payload.get("id"), + ) + for payload in committed_updates: + path = str(payload.get("virtualPath") or "") + _kb_receipt( + type="file", + operation="edit_file", + path=path, + external_id=payload.get("id"), + ) + for payload in applied_moves: + # ``applied_moves`` rows carry the destination ``virtualPath`` because + # the move has already landed in the DB by the time we reach this code. 
+ path = str(payload.get("virtualPath") or "") + _kb_receipt( + type="file", + operation="move_file", + path=path, + external_id=payload.get("id"), + ) + for path in staged_dirs: + _kb_receipt(type="folder", operation="mkdir", path=path) + for payload in committed_deletes: + path = str(payload.get("virtualPath") or "") + _kb_receipt( + type="file", + operation="rm", + path=path, + external_id=payload.get("id"), + ) + for payload in committed_folder_deletes: + path = str(payload.get("virtualPath") or "") + _kb_receipt( + type="folder", + operation="rmdir", + path=path, + external_id=payload.get("id"), + ) + if receipts: + delta["receipts"] = receipts files_delta: dict[str, Any] = {} if temp_paths: files_delta.update(dict.fromkeys(temp_paths)) diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py index cfe1edae4..ecaa042a9 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -16,13 +16,14 @@ dashboards expect. 
from __future__ import annotations import logging +import time from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import AIMessage, ToolMessage -from app.observability import otel as ot +from app.observability import metrics as ot_metrics, otel as ot if TYPE_CHECKING: # pragma: no cover — type-only from langchain.agents.middleware.types import ( @@ -62,14 +63,37 @@ class OtelSpanMiddleware(AgentMiddleware): return await handler(request) model_id, provider = _resolve_model_attrs(request) + t0 = time.perf_counter() with ot.model_call_span(model_id=model_id, provider=provider) as sp: + _annotate_model_request(sp, model_id=model_id, provider=provider) try: result = await handler(request) except Exception: + ot_metrics.record_model_call_duration( + (time.perf_counter() - t0) * 1000, + model=model_id, + provider=provider, + ) # span context manager records + re-raises raise else: - _annotate_model_response(sp, result) + input_tokens, output_tokens = _annotate_model_response( + sp, + result, + model_id=model_id, + provider=provider, + ) + ot_metrics.record_model_call_duration( + (time.perf_counter() - t0) * 1000, + model=model_id, + provider=provider, + ) + ot_metrics.record_model_token_usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + model=model_id, + provider=provider, + ) return result # ------------------------------------------------------------------ @@ -87,9 +111,24 @@ class OtelSpanMiddleware(AgentMiddleware): tool_name = _resolve_tool_name(request) input_size = _resolve_input_size(request) + t0 = time.perf_counter() with ot.tool_call_span(tool_name, input_size=input_size) as sp: - result = await handler(request) - _annotate_tool_result(sp, result) + try: + result = await handler(request) + except Exception: + ot_metrics.record_tool_call_duration( + (time.perf_counter() - t0) * 1000, + tool_name=tool_name, + ) + 
ot_metrics.record_tool_call_error(tool_name=tool_name) + raise + errored = _annotate_tool_result(sp, result) + ot_metrics.record_tool_call_duration( + (time.perf_counter() - t0) * 1000, + tool_name=tool_name, + ) + if errored: + ot_metrics.record_tool_call_error(tool_name=tool_name) return result @@ -154,8 +193,29 @@ def _resolve_input_size(request: Any) -> int | None: return None -def _annotate_model_response(span: Any, result: Any) -> None: +def _annotate_model_request( + span: Any, *, model_id: str | None, provider: str | None +) -> None: + try: + span.set_attribute("gen_ai.operation.name", "chat") + if model_id: + span.set_attribute("gen_ai.request.model", model_id) + if provider: + span.set_attribute("gen_ai.provider.name", provider) + except Exception: # pragma: no cover — defensive + pass + + +def _annotate_model_response( + span: Any, + result: Any, + *, + model_id: str | None = None, + provider: str | None = None, +) -> tuple[int | None, int | None]: """Best-effort: attach prompt/completion token counts when available.""" + input_tokens: int | None = None + output_tokens: int | None = None try: # ModelResponse may be a dataclass with .result containing AIMessage msg: Any @@ -165,22 +225,42 @@ def _annotate_model_response(span: Any, result: Any) -> None: inner = getattr(result, "result", None) msg = inner[-1] if isinstance(inner, list) and inner else inner if msg is None: - return + return None, None + if provider: + span.set_attribute("gen_ai.provider.name", provider) + if model_id: + span.set_attribute("gen_ai.request.model", model_id) + response_model = getattr(msg, "response_metadata", {}) or {} + if isinstance(response_model, dict): + response_model = ( + response_model.get("model_name") + or response_model.get("model") + or response_model.get("model_id") + ) + if not response_model: + response_model = model_id + if response_model: + span.set_attribute("gen_ai.response.model", str(response_model)) + span.set_attribute("gen_ai.operation.name", "chat") 
usage = getattr(msg, "usage_metadata", None) or {} if isinstance(usage, dict): if (n := usage.get("input_tokens")) is not None: - span.set_attribute("tokens.prompt", int(n)) + input_tokens = int(n) + span.set_attribute("gen_ai.usage.input_tokens", input_tokens) if (n := usage.get("output_tokens")) is not None: - span.set_attribute("tokens.completion", int(n)) + output_tokens = int(n) + span.set_attribute("gen_ai.usage.output_tokens", output_tokens) if (n := usage.get("total_tokens")) is not None: - span.set_attribute("tokens.total", int(n)) + span.set_attribute("gen_ai.usage.total_tokens", int(n)) tool_calls = getattr(msg, "tool_calls", None) or [] span.set_attribute("model.tool_calls", len(tool_calls)) except Exception: # pragma: no cover — defensive pass + return input_tokens, output_tokens -def _annotate_tool_result(span: Any, result: Any) -> None: +def _annotate_tool_result(span: Any, result: Any) -> bool: + errored = False try: if isinstance(result, ToolMessage): content = ( @@ -192,11 +272,14 @@ def _annotate_tool_result(span: Any, result: Any) -> None: status = getattr(result, "status", None) if isinstance(status, str): span.set_attribute("tool.status", status) + errored = status.lower() == "error" kwargs = getattr(result, "additional_kwargs", None) or {} if isinstance(kwargs, dict) and kwargs.get("error"): span.set_attribute("tool.error", True) + errored = True except Exception: # pragma: no cover — defensive pass + return errored __all__ = ["OtelSpanMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py index f77b7e387..07549bedb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -61,7 +61,7 @@ from app.agents.new_chat.permissions import ( aggregate_action, evaluate_many, ) -from app.observability import otel as ot +from app.observability import metrics as 
ot_metrics, otel as ot logger = logging.getLogger(__name__) @@ -284,6 +284,8 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] ), ot.interrupt_span(interrupt_type="permission_ask"), ): + ot_metrics.record_permission_ask(permission=tool_name) + ot_metrics.record_interrupt(interrupt_type="permission_ask") decision = interrupt(payload) return _normalize_permission_decision(decision) diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py index 0c3d3d017..321185dee 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py +++ b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py @@ -45,6 +45,8 @@ from langchain.agents.middleware.types import ( from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event from langchain_core.messages import AIMessage +from app.observability import metrics as ot_metrics, otel as ot + logger = logging.getLogger(__name__) # Names of exception classes for which a retry would not help — context @@ -198,6 +200,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp if not self._should_retry(exc) or attempt >= self.max_retries: raise delay = self._delay_for_attempt(attempt, exc) + ot.add_event( + "model.retry.scheduled", + { + "retry.attempt": attempt + 1, + "retry.max": self.max_retries, + "retry.delay_ms": int(delay * 1000), + "retry.reason": ot_metrics.categorize_exception(exc), + }, + ) try: dispatch_custom_event( "surfsense.retrying", @@ -231,6 +242,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp if not self._should_retry(exc) or attempt >= self.max_retries: raise delay = self._delay_for_attempt(attempt, exc) + ot.add_event( + "model.retry.scheduled", + { + "retry.attempt": attempt + 1, + "retry.max": self.max_retries, + "retry.delay_ms": int(delay * 1000), + "retry.reason": 
ot_metrics.categorize_exception(exc), + }, + ) try: await adispatch_custom_event( "surfsense.retrying", diff --git a/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py index 99eb2d74a..0294e2839 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py +++ b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any from langchain.agents.middleware import ModelFallbackMiddleware +from app.observability import metrics as ot_metrics, otel as ot + if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -55,7 +57,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware): raise last_exception = e - for fallback_model in self.models: + for attempt, fallback_model in enumerate(self.models, start=1): + ot.add_event( + "model.fallback", + { + "fallback.attempt": attempt, + "fallback.from": attempt - 1, + "fallback.to": attempt, + "fallback.reason": ot_metrics.categorize_exception(last_exception), + }, + ) try: return handler(request.override(model=fallback_model)) except Exception as e: @@ -79,7 +90,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware): raise last_exception = e - for fallback_model in self.models: + for attempt, fallback_model in enumerate(self.models, start=1): + ot.add_event( + "model.fallback", + { + "fallback.attempt": attempt, + "fallback.from": attempt - 1, + "fallback.to": attempt, + "fallback.reason": ot_metrics.categorize_exception(last_exception), + }, + ) try: return await handler(request.override(model=fallback_model)) except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md index 56291bf3e..3562ce66e 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md +++ 
b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md @@ -59,14 +59,13 @@ Do NOT cite document_id. Always use the chunk id. - NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format - NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only - NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess -- Copy the EXACT chunk id from the XML - if it says ``, use [citation:doc-123] +- Copy the EXACT chunk id from the XML - if it says ``, use [citation:5] - If the chunk id is a URL like ``, use [citation:https://example.com/page] CORRECT citation formats: - [citation:5] (numeric chunk ID from knowledge base) -- [citation:doc-123] (for Surfsense documentation chunks) - [citation:https://example.com/article] (URL chunk ID from web search results) - [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations) diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md index 9cc767e7e..073b75fa5 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md @@ -7,7 +7,7 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: 2. Ask the user: "Would you like me to answer from my general knowledge instead?" 3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes. - This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?"). 
For "how do I use SurfSense" / product-documentation questions, point the user to https://www.surfsense.com/docs. * Formatting, summarization, or analysis of content already present in the conversation * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") * Tool-usage actions like generating reports, podcasts, images, or scraping webpages diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md index 1d806dbae..1a43ed490 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md @@ -7,7 +7,7 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: 2. Ask: "Would you like me to answer from my general knowledge instead?" 3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes. - This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?"). For "how do I use SurfSense" / product-documentation questions, point the user to https://www.surfsense.com/docs. 
* Formatting, summarization, or analysis of content already present in the conversation * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") * Tool-usage actions like generating reports, podcasts, images, or scraping webpages diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md index b8bb069e2..9121de879 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md @@ -13,6 +13,7 @@ When to use which tool: - Knowledge base content (Notion, GitHub, files, notes) → automatically searched - Real-time public web data → call web_search - Reading a specific webpage → call scrape_webpage +- SurfSense product / how-to questions (setup, configuration, connectors, feature behavior) → point the user to the documentation: https://www.surfsense.com/docs **`task` subagents (when to delegate):** - **`linear_specialist`** — Linear-only investigations and tool use. diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md index b081a2123..c5383be77 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md @@ -13,6 +13,7 @@ When to use which tool: - Knowledge base content (Notion, GitHub, files, notes) → automatically searched - Real-time public web data → call web_search - Reading a specific webpage → call scrape_webpage +- SurfSense product / how-to questions (setup, configuration, connectors, feature behavior) → point the user to the documentation: https://www.surfsense.com/docs **`task` subagents (when to delegate):** - **`linear_specialist`** — Linear-only investigations and tool use. 
diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py index 42f8303e6..412665813 100644 --- a/surfsense_backend/app/agents/new_chat/prompts/composer.py +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -151,7 +151,6 @@ def _read_fragment(subpath: str) -> str: # Ordered for reading flow: fundamentals first, then artifact generators, # then memory at the end (mirrors the legacy ``_ALL_TOOL_NAMES_ORDERED``). ALL_TOOL_NAMES_ORDERED: tuple[str, ...] = ( - "search_surfsense_docs", "web_search", "generate_podcast", "generate_video_presentation", diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md deleted file mode 100644 index b90f2b7a7..000000000 --- a/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md +++ /dev/null @@ -1,9 +0,0 @@ - -- User: "How do I install SurfSense?" - - Call: `search_surfsense_docs(query="installation setup")` -- User: "What connectors does SurfSense support?" - - Call: `search_surfsense_docs(query="available connectors integrations")` -- User: "How do I set up the Notion connector?" - - Call: `search_surfsense_docs(query="Notion connector setup configuration")` -- User: "How do I use Docker to run SurfSense?" - - Call: `search_surfsense_docs(query="Docker installation setup")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md deleted file mode 100644 index 133717fec..000000000 --- a/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md +++ /dev/null @@ -1,7 +0,0 @@ - -- search_surfsense_docs: Search the official SurfSense documentation. - - Use this tool when the user asks anything about SurfSense itself (the application they are using). 
- - Args: - - query: The search query about SurfSense - - top_k: Number of documentation chunks to retrieve (default: 10) - - Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123]) diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md index 32e599e98..2dbc8ec43 100644 --- a/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md @@ -1,7 +1,6 @@ --- name: email-drafting description: Draft an email matching the user's voice, with structured intent and CTA -allowed-tools: search_surfsense_docs --- # Email drafting diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md index c268278ab..0f0b5ffbb 100644 --- a/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md @@ -1,7 +1,7 @@ --- name: kb-research description: Structured approach to finding and synthesizing information from the user's knowledge base -allowed-tools: search_surfsense_docs, scrape_webpage, read_file, ls_tree, grep, web_search +allowed-tools: scrape_webpage, read_file, ls_tree, grep, web_search --- # Knowledge-base research diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md index 9657eb078..5a375fbde 100644 --- a/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md @@ -1,7 +1,7 @@ --- name: meeting-prep description: Pull together briefing materials before a scheduled meeting -allowed-tools: search_surfsense_docs, 
web_search, scrape_webpage, read_file +allowed-tools: web_search, scrape_webpage, read_file --- # Meeting preparation diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md index 17ac2f391..cfea9593f 100644 --- a/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md @@ -1,7 +1,7 @@ --- name: report-writing description: How to scope, draft, and revise a Markdown report artifact via generate_report -allowed-tools: generate_report, search_surfsense_docs, read_file +allowed-tools: generate_report, read_file --- # Report writing diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md index 33b9e72a2..1a4c3da9f 100644 --- a/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md @@ -1,7 +1,6 @@ --- name: slack-summary description: Distill a Slack channel or thread into actionable summary -allowed-tools: search_surfsense_docs --- # Slack summarization diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/new_chat/state_reducers.py index 89fc86367..c7b7685f0 100644 --- a/surfsense_backend/app/agents/new_chat/state_reducers.py +++ b/surfsense_backend/app/agents/new_chat/state_reducers.py @@ -171,6 +171,39 @@ def _dict_merge_with_tombstones_reducer( return result +def _int_counter_merge_reducer( + left: dict[str, int] | None, + right: dict[str, int] | None, +) -> dict[str, int]: + """Merge ``right`` into ``left`` by **summing** per-key integer counters. + + Used for state fields that accumulate counts across multiple updates + within the same turn (e.g. per-subagent ``billable_calls``). 
Unknown + keys are added; existing keys are summed. ``_CLEAR`` sentinels reset + the accumulator the same way the other reducers do, so the orchestrator + can wipe the counter at end-of-turn if needed. + """ + if right is None: + return dict(left or {}) + + if _CLEAR in right or any(_is_clear(k) for k in right): + result: dict[str, int] = {} + for key, value in right.items(): + if _is_clear(key): + continue + if not isinstance(value, int): + continue + result[key] = result.get(key, 0) + value + return result + + base = dict(left or {}) + for key, value in right.items(): + if not isinstance(value, int): + continue + base[key] = base.get(key, 0) + value + return base + + def _initial_filesystem_state() -> dict[str, Any]: """Default empty values for SurfSense filesystem state fields. @@ -200,6 +233,7 @@ __all__ = [ "_add_unique_reducer", "_dict_merge_with_tombstones_reducer", "_initial_filesystem_state", + "_int_counter_merge_reducer", "_list_append_reducer", "_replace_reducer", ] diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py index b993d2b06..2cfd47441 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/config.py +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -46,7 +46,6 @@ logger = logging.getLogger(__name__) # ``glob``, ``grep``) plus the SurfSense-side read tools. EXPLORE_READ_TOOLS: frozenset[str] = frozenset( { - "search_surfsense_docs", "web_search", "scrape_webpage", "read_file", @@ -61,7 +60,6 @@ EXPLORE_READ_TOOLS: frozenset[str] = frozenset( # is needed, the parent should hand off to ``explore`` first. REPORT_WRITER_TOOLS: frozenset[str] = frozenset( { - "search_surfsense_docs", "read_file", "generate_report", } @@ -222,7 +220,6 @@ EXPLORE_SYSTEM_PROMPT = """You are the **explore** subagent for SurfSense. Conduct read-only research across the user's knowledge base, the web, and any documents the parent agent has surfaced. 
Return a synthesized answer with explicit citations — never speculate beyond the sources you have actually inspected. ## Tools available -- `search_surfsense_docs` — fast hybrid search over the user's knowledge base. - `web_search` — only when the user's KB clearly does not contain the answer. - `scrape_webpage` — to read a URL the user or the search results provided. - `read_file`, `ls`, `glob`, `grep` — to inspect specific documents or trees the parent has flagged. @@ -242,7 +239,7 @@ Produce a single high-quality report deliverable using `generate_report`. The pa ## Workflow 1. **Outline first.** Before calling `generate_report`, write a one-paragraph outline of the sections you plan to produce. Confirm the outline reflects the parent's instructions. -2. **Source resolution.** Decide whether to call `search_surfsense_docs` and `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set. +2. **Source resolution.** Decide whether to call `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set. 3. **One report.** Call `generate_report` exactly once with `source_strategy` chosen per the topic and chat history (see the `report-writing` skill). 4. **Confirm.** End with a one-sentence summary in your final message — never paste the report back into chat; the artifact card renders itself. """ diff --git a/surfsense_backend/app/agents/new_chat/tools/__init__.py b/surfsense_backend/app/agents/new_chat/tools/__init__.py index bc444b0c0..4b5ae3706 100644 --- a/surfsense_backend/app/agents/new_chat/tools/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/__init__.py @@ -5,7 +5,6 @@ This module contains all the tools available to the SurfSense agent. To add a new tool, see the documentation in registry.py. 
Available tools: -- search_surfsense_docs: Search Surfsense documentation for usage help - generate_podcast: Generate audio podcasts from content - generate_video_presentation: Generate video presentations with slides and narration - generate_image: Generate images from text descriptions using AI models @@ -31,7 +30,6 @@ from .registry import ( get_tool_by_name, ) from .scrape_webpage import create_scrape_webpage_tool -from .search_surfsense_docs import create_search_surfsense_docs_tool from .update_memory import create_update_memory_tool, create_update_team_memory_tool from .video_presentation import create_generate_video_presentation_tool @@ -47,7 +45,6 @@ __all__ = [ "create_generate_podcast_tool", "create_generate_video_presentation_tool", "create_scrape_webpage_tool", - "create_search_surfsense_docs_tool", "create_update_memory_tool", "create_update_team_memory_tool", "format_documents_for_context", diff --git a/surfsense_backend/app/agents/new_chat/tools/podcast.py b/surfsense_backend/app/agents/new_chat/tools/podcast.py index 2c9b7fa0c..83ac98768 100644 --- a/surfsense_backend/app/agents/new_chat/tools/podcast.py +++ b/surfsense_backend/app/agents/new_chat/tools/podcast.py @@ -2,17 +2,23 @@ Podcast generation tool for the SurfSense agent. This module provides a factory function for creating the generate_podcast tool -that submits a Celery task for background podcast generation. The frontend -polls for completion and auto-updates when the podcast is ready. +that submits a Celery task for background podcast generation. The tool then +polls the podcast row until it reaches a terminal status (READY/FAILED) and +returns that status. The wait is bounded by the chat's HTTP / process +lifetime; see app.agents.shared.deliverable_wait for details. 
""" +import logging from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.shared.deliverable_wait import wait_for_deliverable from app.db import Podcast, PodcastStatus, shielded_async_session +logger = logging.getLogger(__name__) + def create_generate_podcast_tool( search_space_id: int, @@ -97,18 +103,53 @@ def create_generate_podcast_tool( user_prompt=user_prompt, ) - print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}") + logger.info( + "[generate_podcast] Created podcast %s, task: %s", + podcast_id, + task.id, + ) + # Wait until the Celery worker flips the row to a terminal + # state. No internal budget — see deliverable_wait module. + terminal_status, columns, elapsed = await wait_for_deliverable( + model=Podcast, + row_id=podcast_id, + columns=[Podcast.status, Podcast.file_location], + terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, + ) + + if terminal_status == PodcastStatus.READY: + file_location = columns[1] if columns else None + logger.info( + "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", + podcast_id, + elapsed, + file_location, + ) + return { + "status": PodcastStatus.READY.value, + "podcast_id": podcast_id, + "title": podcast_title, + "file_location": file_location, + "message": ("Podcast generated and saved to your podcast panel."), + } + + # Only other terminal state is FAILED. + logger.warning( + "[generate_podcast] Podcast %s FAILED in %.2fs", + podcast_id, + elapsed, + ) return { - "status": PodcastStatus.PENDING.value, + "status": PodcastStatus.FAILED.value, "podcast_id": podcast_id, "title": podcast_title, - "message": "Podcast generation started. 
This may take a few minutes.", + "error": ("Background worker reported FAILED status for this podcast."), } except Exception as e: error_message = str(e) - print(f"[generate_podcast] Error: {error_message}") + logger.exception("[generate_podcast] Error: %s", error_message) return { "status": PodcastStatus.FAILED.value, "error": error_message, diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index b842d7a20..6f011e372 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -101,7 +101,6 @@ from .podcast import create_generate_podcast_tool from .report import create_generate_report_tool from .resume import create_generate_resume_tool from .scrape_webpage import create_scrape_webpage_tool -from .search_surfsense_docs import create_search_surfsense_docs_tool from .teams import ( create_list_teams_channels_tool, create_read_teams_messages_tool, @@ -150,6 +149,28 @@ class ToolDefinition: reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None +# ============================================================================= +# Deferred-import factories +# ============================================================================= +# Used for tools whose impls live under ``multi_agent_chat``. Importing those +# at module-load time would cycle (``multi_agent_chat`` middleware imports +# this registry). The import inside the factory runs only when +# ``build_tools`` is called, by which point ``multi_agent_chat`` is fully +# initialised. 
+ + +def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool: + from app.agents.multi_agent_chat.main_agent.tools.automation import ( + create_create_automation_tool, + ) + + return create_create_automation_tool( + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + llm=deps["llm"], + ) + + # ============================================================================= # Built-in Tools Registry # ============================================================================= @@ -236,15 +257,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=[], ), - # Surfsense documentation search tool - ToolDefinition( - name="search_surfsense_docs", - description="Search Surfsense documentation for help with using the application", - factory=lambda deps: create_search_surfsense_docs_tool( - db_session=deps["db_session"], - ), - requires=["db_session"], - ), # ========================================================================= # SERVICE ACCOUNT DISCOVERY # Generic tool for the LLM to discover connected accounts and resolve @@ -261,6 +273,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ requires=["db_session", "search_space_id", "user_id"], ), # ========================================================================= + # AUTOMATION AUTHORING - single HITL tool. The tool takes an NL ``intent`` + # from the main agent, drafts the full AutomationCreate JSON via a focused + # sub-LLM, surfaces it on an approval card, and persists on approval. The + # factory defers its import because the impl lives under ``multi_agent_chat`` + # and that package transitively pulls this registry via middleware; + # deferring to ``build_tools`` call-time breaks the cycle without a + # parallel registry. 
+ # ========================================================================= + ToolDefinition( + name="create_automation", + description="Draft an automation from an NL intent; user approves the card; tool saves", + factory=_build_create_automation_tool, + requires=["search_space_id", "user_id", "llm"], + ), + # ========================================================================= # MEMORY TOOL - single update_memory, private or team by thread_visibility # ========================================================================= ToolDefinition( diff --git a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py b/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py deleted file mode 100644 index d8a0efac7..000000000 --- a/surfsense_backend/app/agents/new_chat/tools/search_surfsense_docs.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Surfsense documentation search tool. - -This tool allows the agent to search the pre-indexed Surfsense documentation -to help users with questions about how to use the application. - -The documentation is indexed at deployment time from MDX files and stored -in dedicated tables (surfsense_docs_documents, surfsense_docs_chunks). -""" - -import asyncio -import json - -from langchain_core.tools import tool -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker -from app.utils.document_converters import embed_text -from app.utils.surfsense_docs import surfsense_docs_public_url - - -def format_surfsense_docs_results(results: list[tuple]) -> str: - """ - Format search results into XML structure for the LLM context. - - Uses the same XML structure as format_documents_for_context from knowledge_base.py - but with 'doc-' prefix on chunk IDs. 
This allows: - - LLM to use consistent [citation:doc-XXX] format - - Frontend to detect 'doc-' prefix and route to surfsense docs endpoint - - Args: - results: List of (chunk, document) tuples from the database query - - Returns: - Formatted XML string with documentation content and citation-ready chunks - """ - if not results: - return "No relevant Surfsense documentation found for your query." - - # Group chunks by document - grouped: dict[int, dict] = {} - for chunk, doc in results: - public_url = surfsense_docs_public_url(doc.source) - if doc.id not in grouped: - grouped[doc.id] = { - "document_id": f"doc-{doc.id}", - "document_type": "SURFSENSE_DOCS", - "title": doc.title, - "url": public_url, - "metadata": {"source": doc.source, "public_url": public_url}, - "chunks": [], - } - grouped[doc.id]["chunks"].append( - { - "chunk_id": f"doc-{chunk.id}", - "content": chunk.content, - } - ) - - # Render XML matching format_documents_for_context structure - parts: list[str] = [] - for g in grouped.values(): - metadata_json = json.dumps(g["metadata"], ensure_ascii=False) - - parts.append("") - parts.append("") - parts.append(f" {g['document_id']}") - parts.append(f" {g['document_type']}") - parts.append(f" <![CDATA[{g['title']}]]>") - parts.append(f" ") - parts.append(f" ") - parts.append("") - parts.append("") - parts.append("") - - for ch in g["chunks"]: - parts.append( - f" " - ) - - parts.append("") - parts.append("") - parts.append("") - - return "\n".join(parts).strip() - - -async def search_surfsense_docs_async( - query: str, - db_session: AsyncSession, - top_k: int = 10, -) -> str: - """ - Search Surfsense documentation using vector similarity. 
- - Args: - query: The search query about Surfsense usage - db_session: Database session for executing queries - top_k: Number of results to return - - Returns: - Formatted string with relevant documentation content - """ - # Get embedding for the query - query_embedding = await asyncio.to_thread(embed_text, query) - - # Vector similarity search on chunks, joining with documents - stmt = ( - select(SurfsenseDocsChunk, SurfsenseDocsDocument) - .join( - SurfsenseDocsDocument, - SurfsenseDocsChunk.document_id == SurfsenseDocsDocument.id, - ) - .order_by(SurfsenseDocsChunk.embedding.op("<=>")(query_embedding)) - .limit(top_k) - ) - - result = await db_session.execute(stmt) - rows = result.all() - - return format_surfsense_docs_results(rows) - - -def create_search_surfsense_docs_tool(db_session: AsyncSession): - """ - Factory function to create the search_surfsense_docs tool. - - The tool acquires its own short-lived ``AsyncSession`` per call via - :data:`async_session_maker` so the closure is safe to share across - HTTP requests by the compiled-agent cache. Capturing a per-request - session here would surface stale/closed sessions on cache hits. - - Args: - db_session: Reserved for registry compatibility. Per-call sessions - are opened via :data:`async_session_maker` inside the tool body. - - Returns: - A configured tool function for searching Surfsense documentation - """ - del db_session # per-call session — see docstring - - @tool - async def search_surfsense_docs(query: str, top_k: int = 10) -> str: - """ - Search Surfsense documentation for help with using the application. - - Use this tool when the user asks questions about: - - How to use Surfsense features - - Installation and setup instructions - - Configuration options and settings - - Troubleshooting common issues - - Available connectors and integrations - - Browser extension usage - - API documentation - - This searches the official Surfsense documentation that was indexed - at deployment time. 
It does NOT search the user's personal knowledge base. - - Args: - query: The search query about Surfsense usage or features - top_k: Number of documentation chunks to retrieve (default: 10) - - Returns: - Relevant documentation content formatted with chunk IDs for citations - """ - async with async_session_maker() as db_session: - return await search_surfsense_docs_async( - query=query, - db_session=db_session, - top_k=top_k, - ) - - return search_surfsense_docs diff --git a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py index 7bf9a1c3b..34f5183ca 100644 --- a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py +++ b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py @@ -2,17 +2,23 @@ Video presentation generation tool for the SurfSense agent. This module provides a factory function for creating the generate_video_presentation -tool that submits a Celery task for background video presentation generation. -The frontend polls for completion and auto-updates when the presentation is ready. +tool that submits a Celery task for background video presentation generation. The +tool then polls the row until it reaches a terminal status (READY/FAILED) and +returns that status. The wait is bounded by the chat's HTTP / process lifetime; +see app.agents.shared.deliverable_wait for details. 
""" +import logging from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.shared.deliverable_wait import wait_for_deliverable from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session +logger = logging.getLogger(__name__) + def create_generate_video_presentation_tool( search_space_id: int, @@ -72,20 +78,56 @@ def create_generate_video_presentation_tool( user_prompt=user_prompt, ) - print( - f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}" + logger.info( + "[generate_video_presentation] Created video presentation %s, task: %s", + video_pres_id, + task.id, ) + # Wait until the Celery worker flips the row to a terminal + # state. No internal budget — see deliverable_wait module. + terminal_status, _columns, elapsed = await wait_for_deliverable( + model=VideoPresentation, + row_id=video_pres_id, + columns=[VideoPresentation.status], + terminal_statuses={ + VideoPresentationStatus.READY, + VideoPresentationStatus.FAILED, + }, + ) + + if terminal_status == VideoPresentationStatus.READY: + logger.info( + "[generate_video_presentation] %s READY in %.2fs", + video_pres_id, + elapsed, + ) + return { + "status": VideoPresentationStatus.READY.value, + "video_presentation_id": video_pres_id, + "title": video_title, + "message": "Video presentation generated and saved.", + } + + # Only other terminal state is FAILED. + logger.warning( + "[generate_video_presentation] %s FAILED in %.2fs", + video_pres_id, + elapsed, + ) return { - "status": VideoPresentationStatus.PENDING.value, + "status": VideoPresentationStatus.FAILED.value, "video_presentation_id": video_pres_id, "title": video_title, - "message": "Video presentation generation started. This may take a few minutes.", + "error": ( + "Background worker reported FAILED status for this " + "video presentation." 
+ ), } except Exception as e: error_message = str(e) - print(f"[generate_video_presentation] Error: {error_message}") + logger.exception("[generate_video_presentation] Error: %s", error_message) return { "status": VideoPresentationStatus.FAILED.value, "error": error_message, diff --git a/surfsense_backend/app/agents/shared/__init__.py b/surfsense_backend/app/agents/shared/__init__.py new file mode 100644 index 000000000..7c46c65ff --- /dev/null +++ b/surfsense_backend/app/agents/shared/__init__.py @@ -0,0 +1,9 @@ +"""Cross-package agent contracts. + +Symbols here are intentionally framework-light (no LangGraph / deepagents +internals) so they can be imported from both ``app.agents.new_chat`` and +``app.agents.multi_agent_chat`` without creating a circular dependency +between the two packages. See ``receipt.py`` for the rationale. +""" + +from __future__ import annotations diff --git a/surfsense_backend/app/agents/shared/deliverable_wait.py b/surfsense_backend/app/agents/shared/deliverable_wait.py new file mode 100644 index 000000000..abaa017ea --- /dev/null +++ b/surfsense_backend/app/agents/shared/deliverable_wait.py @@ -0,0 +1,123 @@ +"""Shared poll-until-terminal helper for Celery-backed deliverables. + +Lives in ``app.agents.shared`` (neutral package, no dependencies on either +``new_chat`` or ``multi_agent_chat``) so both the flat single-agent tools +under ``app/agents/new_chat/tools/`` and the multi-agent subagent tools +under ``app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/`` +can import it without creating a circular dependency. + +Background +---------- +Tools like ``generate_podcast`` and ``generate_video_presentation`` enqueue +the heavy work to Celery and historically returned immediately with a +"pending" status. 
That works for very-long deliverables but hurts UX for +the common case (most podcasts finish in 10-30 seconds): the agent sends +a "kicked off, check back in a minute" reply *before* the worker is done, +so the user never gets a "ready" confirmation. + +This helper bridges that gap. The tool dispatches the Celery task as +before, then polls the artefact row's ``status`` column **until it +reaches a terminal value** (READY / FAILED). The tool then returns a +real terminal outcome — never a pending one. + +No wall-clock budget here on purpose +------------------------------------ +Layering a second budget on top of the existing per-invocation safety +nets just confused the UX. The real ceilings are: + +* **Multi-agent mode** — ``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` + (default ``300.0``, ``0`` to disable) caps how long any single + ``task(subagent, ...)`` invocation can run. If a deliverable needs + longer than this, the subagent invocation is cancelled and the + orchestrator surfaces a "subagent timed out" ToolMessage. Operators + who routinely generate long videos should raise that ceiling (or set + it to ``0`` for true unbounded waits). +* **Single-agent mode** — the chat's HTTP stream / process lifetime is + the only ceiling. Truly indefinite waits work here, but a dead Celery + worker will leave the row in PENDING/GENERATING forever; treat that + as an operational concern, not a UX concern. + +Configuration +------------- +None. The poll cadence is hardcoded at 1.5s — small enough to feel +responsive (~6 polls per typical 10s podcast), large enough to avoid +hammering the DB under burst traffic. Override at the call site if a +specific tool needs a different cadence. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import time +from enum import Enum +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import InstrumentedAttribute + +from app.db import shielded_async_session + +logger = logging.getLogger(__name__) + + +_DEFAULT_POLL_INTERVAL_SECONDS: float = 1.5 + + +async def wait_for_deliverable( + *, + model: type, + row_id: int, + columns: list[InstrumentedAttribute[Any]], + terminal_statuses: set[Enum], + poll_interval_s: float = _DEFAULT_POLL_INTERVAL_SECONDS, +) -> tuple[Enum, tuple[Any, ...], float]: + """Poll ``model`` row ``row_id`` until ``columns[0]`` reaches a terminal status. + + Blocks until the row's status column matches one of + ``terminal_statuses``. There is no internal wall-clock budget; cancel + from the outside (subagent timeout, HTTP disconnect, task + cancellation) if you need a ceiling. See module docstring. + + The first entry of ``columns`` must be the status column; additional + columns (e.g. ``Podcast.file_location``) are returned alongside the + final status so callers can build their payload without a second + roundtrip. + + A fresh ``shielded_async_session`` is opened per poll so we never + hold a transaction across the wait, and a failed poll is logged but + does not abort the wait — transient DB hiccups should not collapse + the tool call. + + Returns + ------- + ``(terminal_status, columns, elapsed_seconds)`` + ``columns`` mirrors the requested ``columns`` (including the + status itself in position 0). + """ + if not columns: + raise ValueError("wait_for_deliverable requires at least the status column") + + start = time.monotonic() + + while True: + await asyncio.sleep(poll_interval_s) + row: tuple[Any, ...] 
| None = None + try: + async with shielded_async_session() as session: + result = await session.execute( + select(*columns).where(model.id == row_id) + ) + row = result.first() + except Exception as exc: + logger.warning( + "[deliverable_wait] poll failed model=%s id=%s err=%r", + getattr(model, "__name__", str(model)), + row_id, + exc, + ) + + if row is not None: + status_val = row[0] + if status_val in terminal_statuses: + return status_val, tuple(row), time.monotonic() - start diff --git a/surfsense_backend/app/agents/shared/receipt.py b/surfsense_backend/app/agents/shared/receipt.py new file mode 100644 index 000000000..6f30067ee --- /dev/null +++ b/surfsense_backend/app/agents/shared/receipt.py @@ -0,0 +1,161 @@ +"""Receipt: structured handle returned by every mutating subagent tool. + +Generalises the Hermes ``entry`` dict (see ``references/hermes-agent/tools/ +delegate_tool.py:1663-1697``) for our 5 deliverable types + 15 connectors + +KB writes. The supervisor reads the Receipt to verify what actually happened +without round-tripping through LLM paraphrase. + +**Why this lives under ``app.agents.shared`` and not under either of the +two agent packages:** the Receipt is a *contract* shared between +``multi_agent_chat`` (where mutating tools emit it) and ``new_chat`` +(where ``filesystem_state.SurfSenseFilesystemState`` declares the +``receipts`` reducer that accumulates it, and where +``middleware.kb_persistence`` emits its own KB-write receipts). Putting +the contract in either package would create a bidirectional import +between the two — see the commit that introduced this module for the +``ImportError`` chain it broke. + +Each mutating tool wraps its native return shape into a Receipt via +:func:`make_receipt` (or builds one directly) and returns it under the +``"receipt"`` key alongside its existing payload. 
The subagent boundary +machinery in ``checkpointed_subagent_middleware.task_tool`` then folds +the receipt into the parent's ``receipts`` state via the append reducer. + +The KB write path is the one exception: file-tool calls cannot emit a +durable receipt because the actual DB writes happen end-of-turn inside +:class:`app.agents.new_chat.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`. +KB tools therefore emit a *provisional* receipt with ``status="pending"``; +the persistence middleware flips it to ``"success"`` or ``"failed"`` +before returning control to the parent. +""" + +from __future__ import annotations + +from typing import Any, Literal, TypedDict + +# Subagent that emitted this receipt. +ReceiptRoute = Literal[ + "deliverables", + "knowledge_base", + "notion", + "slack", + "gmail", + "linear", + "jira", + "clickup", + "confluence", + "calendar", + "luma", + "airtable", + "google_drive", + "dropbox", + "onedrive", + "discord", + "teams", +] + +# Within-route kind of artefact / external resource the operation touched. +# Left as ``str`` rather than a giant union so each route file documents +# its own enum next to its tools. +ReceiptType = str + +# Operation verb. Kept open for the same reason as ``ReceiptType``. +ReceiptOperation = str + +# Pending = async backend (Celery podcast / video) that the orchestrator +# will surface progress for out of band; persistence-MW flipped this to +# ``success`` for KB writes that committed. +ReceiptStatus = Literal["success", "pending", "failed"] + + +class Receipt(TypedDict, total=False): + """Structured per-mutation handle returned to the parent subagent. + + All fields are ``NotRequired`` (TypedDict ``total=False``) so each + route's tool can populate only the fields it actually has — e.g. Gmail + never sets ``verifiable_url`` because Gmail doesn't expose per-message + URLs. The receipts state reducer treats missing keys as missing rather + than ``null`` so we don't double-count. 
+ """ + + route: ReceiptRoute + """Subagent name. Lets the orchestrator filter ``state['receipts']`` + by route without re-deriving from ``type``.""" + + type: ReceiptType + """Within-route kind. e.g. for ``deliverables`` one of ``{report, + podcast, video_presentation, resume, image}``; for ``notion`` ``page``; + for ``slack`` ``message``.""" + + operation: ReceiptOperation + """Verb. e.g. ``generate`` (deliverables), ``create`` / ``update`` / + ``delete`` (most connectors), ``send`` / ``post`` (chat), ``write_file`` + / ``edit_file`` / ``rm`` / ``rmdir`` / ``move_file`` / ``mkdir`` (KB).""" + + status: ReceiptStatus + """``success`` / ``pending`` / ``failed``. The verification teaching + in ``shared/snippets/verifiable_handle.md`` keys off this field.""" + + external_id: str | None + """Backend identifier. Report row id, Notion ``page_id``, Slack ``ts``, + Gmail ``message_id``, Linear identifier, KB ``virtualPath``, etc. + ``None`` only when the operation failed before the backend assigned one.""" + + verifiable_url: str | None + """URL the parent can pass to ``scrape_webpage`` to verify the + operation. ``None`` when no public URL exists (Gmail, KB, raw images + stored in the DB).""" + + preview: str | None + """Short snippet (~200 chars) of what was produced. First lines of + a generated report's markdown, transcript opener for a podcast, + thumbnail URL for an image. Lets the orchestrator decide whether to + re-render in the UI without re-loading the artefact.""" + + error: str | None + """Filled iff ``status == "failed"``. Plain-text reason; the parent + surfaces it in its own ``next_step``.""" + + +def make_receipt( + *, + route: ReceiptRoute, + type: str, + operation: str, + status: ReceiptStatus, + external_id: str | None = None, + verifiable_url: str | None = None, + preview: str | None = None, + error: str | None = None, +) -> Receipt: + """Construct a :class:`Receipt` with non-``None`` fields only. 
+ + Drops keys whose value is ``None`` so downstream consumers can use + ``"verifiable_url" in receipt`` to distinguish "tool returned no URL" + from "tool deliberately surfaced ``null``". + """ + out: dict[str, Any] = { + "route": route, + "type": type, + "operation": operation, + "status": status, + } + if external_id is not None: + out["external_id"] = external_id + if verifiable_url is not None: + out["verifiable_url"] = verifiable_url + if preview is not None: + out["preview"] = preview + if error is not None: + out["error"] = error + return out # type: ignore[return-value] + + +__all__ = [ + "Receipt", + "ReceiptOperation", + "ReceiptRoute", + "ReceiptStatus", + "ReceiptType", + "make_receipt", +] diff --git a/surfsense_backend/app/agents/shared/receipt_command.py b/surfsense_backend/app/agents/shared/receipt_command.py new file mode 100644 index 000000000..f1c269e90 --- /dev/null +++ b/surfsense_backend/app/agents/shared/receipt_command.py @@ -0,0 +1,71 @@ +"""Helper for wrapping a tool result with a Receipt in a ``Command(update=...)``. + +Most mutating subagent tools historically returned a plain ``dict`` payload +which deepagents serialised straight into the ``ToolMessage`` content. To +participate in the verification teaching from +``multi_agent_chat/subagents/shared/snippets/verifiable_handle.md`` those +tools now also need to write a :class:`Receipt` into the parent's +``state['receipts']`` list (declared on +:class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState` +and backed by the append reducer). + +:func:`with_receipt` wraps both behaviours: it returns the tool payload as +a JSON-encoded ``ToolMessage`` AND appends the receipt to state in a single +:class:`~langgraph.types.Command`. Use it at every ``return`` site of a +mutating tool — including failure paths (emit a receipt with +``status="failed"`` and the error message in ``error``). 
+""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain_core.messages import ToolMessage +from langgraph.types import Command + +from app.agents.shared.receipt import Receipt + + +def _content_to_text(payload: dict[str, Any] | str) -> str: + """Serialise a tool payload to ``ToolMessage`` content. + + Dicts go through ``json.dumps`` (matching deepagents' default tool-result + serialisation); strings are passed through. Anything else is coerced via + ``str`` so we never raise here — a mis-typed tool return would already + have failed inside the tool body. + """ + if isinstance(payload, str): + return payload + if isinstance(payload, dict): + return json.dumps(payload, default=str) + return str(payload) + + +def with_receipt( + *, + payload: dict[str, Any] | str, + receipt: Receipt, + tool_call_id: str, +) -> Command: + """Return a Command that ships ``payload`` as a ToolMessage AND appends ``receipt``. + + The append happens via the ``_list_append_reducer`` on the ``receipts`` + field of :class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState`, + so concurrent subagent batches (item 4 in the plan) won't clobber each + other's receipts. 
+ """ + return Command( + update={ + "messages": [ + ToolMessage( + content=_content_to_text(payload), + tool_call_id=tool_call_id, + ) + ], + "receipts": [receipt], + } + ) + + +__all__ = ["with_receipt"] diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index fc6242643..9bd637ba6 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import gc import logging import time @@ -36,13 +37,15 @@ from app.config import ( ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError +from app.observability import metrics as ot_metrics +from app.observability.bootstrap import init_otel, shutdown_otel from app.rate_limiter import get_real_client_ip, limiter from app.routes import router as crud_router from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate -from app.tasks.surfsense_docs_indexer import seed_surfsense_docs +from app.session_events import register_session_hooks from app.users import SECRET, auth_backend, current_active_user, fastapi_users -from app.utils.perf import get_perf_logger, log_system_snapshot +from app.utils.perf import log_system_snapshot _error_logger = logging.getLogger("surfsense.errors") @@ -127,6 +130,8 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons logged server-side. """ rid = _get_request_id(request) + if exc.status_code in {401, 403} and request.url.path.startswith("/auth"): + ot_metrics.record_auth_failure(reason=_status_to_code(exc.status_code)) should_sanitize = exc.status_code == 500 # Structured dict details (e.g. 
{"code": "CAPTCHA_REQUIRED", "message": "..."}) @@ -213,6 +218,7 @@ def _validation_error_handler( def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: """Catch-all: log full traceback, return sanitized 500.""" rid = _get_request_id(request) + ot_metrics.record_auth_failure(reason="unhandled_exception") _error_logger.error( "[%s] Unhandled exception on %s %s", rid, @@ -246,6 +252,7 @@ def _status_to_code(status_code: int, detail: str = "") -> str: def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded): """Custom 429 handler that returns JSON matching our error envelope.""" rid = _get_request_id(request) + ot_metrics.record_rate_limit_rejection(scope="slowapi") retry_after = exc.detail.split("per")[-1].strip() if exc.detail else "60" return _build_error_response( 429, @@ -306,6 +313,7 @@ def _check_rate_limit_memory( f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} " f"({len(timestamps)}/{max_requests} in {window_seconds}s)" ) + ot_metrics.record_rate_limit_rejection(scope=scope) raise HTTPException( status_code=429, detail="RATE_LIMIT_EXCEEDED", @@ -349,6 +357,7 @@ def _check_rate_limit( f"Rate limit exceeded on {scope} for IP {client_ip} " f"({current_count}/{max_requests} in {window_seconds}s)" ) + ot_metrics.record_rate_limit_rejection(scope=scope) raise HTTPException( status_code=429, detail="RATE_LIMIT_EXCEEDED", @@ -558,6 +567,7 @@ async def lifespan(app: FastAPI): gc.set_threshold(700, 10, 5) _enable_slow_callback_logging(threshold_sec=0.5) + init_otel(app) await create_db_and_tables() await setup_checkpointer_tables() initialize_openrouter_integration() @@ -566,13 +576,6 @@ async def lifespan(app: FastAPI): initialize_llm_router() initialize_image_gen_router() initialize_vision_llm_router() - try: - await asyncio.wait_for(seed_surfsense_docs(), timeout=120) - except TimeoutError: - logging.getLogger(__name__).warning( - "Surfsense docs seeding timed out after 120s — skipping. 
" - "Docs will be indexed on the next restart." - ) # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. ``shield`` so Uvicorn cancelling startup @@ -586,12 +589,14 @@ async def lifespan(app: FastAPI): "first real request will pay the full compile cost." ) + register_session_hooks() log_system_snapshot("startup_complete") yield _stop_openrouter_background_refresh() await close_checkpointer() + shutdown_otel() def registration_allowed(): @@ -676,32 +681,20 @@ class RequestPerfMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: StarletteRequest, call_next: RequestResponseEndpoint ) -> StarletteResponse: - perf = get_perf_logger() t0 = time.perf_counter() response = await call_next(request) elapsed_ms = (time.perf_counter() - t0) * 1000 path = request.url.path - method = request.method - status = response.status_code - - perf.debug( - "[request] %s %s -> %d in %.1fms", - method, - path, - status, - elapsed_ms, - ) if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD: - perf.warning( - "[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)", - method, - path, - status, - elapsed_ms, - _PERF_SLOW_REQUEST_THRESHOLD, - ) + with contextlib.suppress(Exception): + from opentelemetry import trace + + span = trace.get_current_span() + span.set_attribute("slow_request", True) + span.set_attribute("surfsense.request.elapsed_ms", elapsed_ms) + span.set_attribute("http.route", path) log_system_snapshot("slow_request") return response diff --git a/surfsense_backend/app/automations/__init__.py b/surfsense_backend/app/automations/__init__.py new file mode 100644 index 000000000..a4ce8ecc9 --- /dev/null +++ b/surfsense_backend/app/automations/__init__.py @@ -0,0 +1,5 @@ +"""Automations engine — see automation-design-plan.md.""" + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/surfsense_backend/app/automations/actions/__init__.py b/surfsense_backend/app/automations/actions/__init__.py new file mode 100644 index 
000000000..ac5a07ac4 --- /dev/null +++ b/surfsense_backend/app/automations/actions/__init__.py @@ -0,0 +1,24 @@ +"""Actions domain: registry surface + built-in action packages. + +Each action lives in its own subpackage (``agent_task/``, ...) and self-registers +at import time via its ``definition`` module. Side-effect imports below ensure +the registry is populated whenever anyone touches the actions package. +""" + +from __future__ import annotations + +from .store import all_actions, get_action, register_action +from .types import ActionContext, ActionDefinition, ActionHandler, ActionHandlerFactory + +__all__ = [ + "ActionContext", + "ActionDefinition", + "ActionHandler", + "ActionHandlerFactory", + "all_actions", + "get_action", + "register_action", +] + +# Built-in actions self-register at import time. +from . import builtin # noqa: F401 diff --git a/surfsense_backend/app/automations/actions/builtin/__init__.py b/surfsense_backend/app/automations/actions/builtin/__init__.py new file mode 100644 index 000000000..f3d21a044 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/__init__.py @@ -0,0 +1,5 @@ +"""Built-in action types — each in its own subpackage, self-registering at import.""" + +from __future__ import annotations + +from . import agent_task # noqa: F401 diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/__init__.py b/surfsense_backend/app/automations/actions/builtin/agent_task/__init__.py new file mode 100644 index 000000000..3a42a2815 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/__init__.py @@ -0,0 +1,15 @@ +"""``agent_task`` action: spin up multi_agent_chat for one rendered query. + +Imports ``definition`` for its side-effect (self-registration on the actions +registry) and re-exports ``build_handler`` for direct consumers. 
+""" + +from __future__ import annotations + +from .factory import build_handler +from .params import AgentTaskActionParams + +__all__ = ["AgentTaskActionParams", "build_handler"] + +# Side-effect: register on the actions store. +from . import definition # noqa: F401 diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/auto_decide.py b/surfsense_backend/app/automations/actions/builtin/agent_task/auto_decide.py new file mode 100644 index 000000000..357eeb565 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/auto_decide.py @@ -0,0 +1,39 @@ +"""Synthesize HITL decisions for every pending interrupt (approve-all or reject-all).""" + +from __future__ import annotations + +from typing import Any + + +def build_auto_decisions( + state: Any, decision: str +) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]: + """Return ``(lg_resume_map, surfsense_resume_value)`` covering every pending interrupt. + + ``lg_resume_map`` is keyed by ``Interrupt.id`` for ``Command(resume=...)``; + ``surfsense_resume_value`` is keyed by ``tool_call_id`` for the subagent + middleware bridge. Action count is read from ``value.action_requests`` when + present and falls back to ``1`` for wrapped scalar interrupts. 
+ """ + lg_resume_map: dict[str, dict[str, Any]] = {} + routed: dict[str, dict[str, Any]] = {} + + for interrupt_obj in getattr(state, "interrupts", ()) or (): + value = getattr(interrupt_obj, "value", None) + if not isinstance(value, dict): + continue + interrupt_id = getattr(interrupt_obj, "id", None) + if not isinstance(interrupt_id, str): + continue + + action_requests = value.get("action_requests") + count = len(action_requests) if isinstance(action_requests, list) else 1 + decisions = [{"type": decision} for _ in range(count)] + + lg_resume_map[interrupt_id] = {"decisions": decisions} + + tool_call_id = value.get("tool_call_id") + if isinstance(tool_call_id, str): + routed[tool_call_id] = {"decisions": decisions} + + return lg_resume_map, routed diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/definition.py b/surfsense_backend/app/automations/actions/builtin/agent_task/definition.py new file mode 100644 index 000000000..cc3fd563a --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/definition.py @@ -0,0 +1,18 @@ +"""``agent_task`` ``ActionDefinition`` registration.""" + +from __future__ import annotations + +from ...store import register_action +from ...types import ActionDefinition +from .factory import build_handler +from .params import AgentTaskActionParams + +AGENT_TASK_ACTION = ActionDefinition( + type="agent_task", + name="Agent task", + description="Run a multi_agent_chat turn from an automation step.", + params_model=AgentTaskActionParams, + build_handler=build_handler, +) + +register_action(AGENT_TASK_ACTION) diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py new file mode 100644 index 000000000..e3736cc95 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -0,0 +1,112 @@ +"""Build the per-invocation dependencies the multi_agent_chat 
factory needs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from langgraph.checkpoint.memory import InMemorySaver +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.services.model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + assert_models_billable, +) +from app.db import SearchSpace +from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle +from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( + setup_connector_and_firecrawl, +) + + +class DependencyError(Exception): + """An external dependency (LLM config, connector service, ...) refused to load.""" + + +@dataclass(frozen=True, slots=True) +class AgentDependencies: + """Everything ``create_multi_agent_chat_deep_agent`` needs from the environment.""" + + llm: Any + agent_config: Any + connector_service: Any + firecrawl_api_key: str | None + checkpointer: Any + + +async def build_dependencies( + *, + session: AsyncSession, + search_space_id: int, + agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, +) -> AgentDependencies: + """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. + + Resolves the agent LLM from the automation's *captured* model snapshot + (``agent_llm_id``) so runs are insulated from later chat/search-space model + changes. The model policy is enforced here as a runtime backstop: a captured + model that is no longer billable (e.g. a premium global config was removed) + fails the run clearly instead of silently consuming a free model. + + When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``agent_llm_id`` and validate that. 
+ """ + if agent_llm_id is not None: + try: + assert_models_billable( + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, + ) + except AutomationModelPolicyError as exc: + raise DependencyError(str(exc)) from exc + resolved_agent_llm_id = agent_llm_id or 0 + else: + search_space = await session.get(SearchSpace, search_space_id) + if search_space is None: + raise DependencyError(f"search space {search_space_id} not found") + try: + assert_automation_models_billable(search_space) + except AutomationModelPolicyError as exc: + raise DependencyError(str(exc)) from exc + resolved_agent_llm_id = search_space.agent_llm_id or 0 + + llm, agent_config, err = await load_llm_bundle( + session, + config_id=resolved_agent_llm_id, + search_space_id=search_space_id, + ) + if err is not None or llm is None: + raise DependencyError(err or "failed to load agent LLM config") + + connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( + session, search_space_id=search_space_id + ) + # Quick fix: use an in-memory checkpointer for automation runs. + # + # The shared Postgres checkpointer caches DB connections in a + # module-level pool. Each cached connection is bound to the asyncio + # loop that opened it. Celery throws away the loop after every task, + # so the pool ends up full of connections pointing to a dead loop, + # and the next Celery task (running on a fresh loop) can't use any + # of them — it hangs 30s and fails with + # `PoolTimeout: couldn't get a connection after 30.00 sec`. + # + # InMemorySaver has no cached connections, no loop binding — each + # Celery task creates one and drops it on exit. + # + # TODO(checkpointer): proper fix is to dispose the checkpointer + # pool around each Celery task in `run_async_celery_task`, the same + # way `_dispose_shared_db_engine` already does for the SQLAlchemy + # pool. Then this site can switch back to the shared checkpointer. 
+ checkpointer = InMemorySaver() + return AgentDependencies( + llm=llm, + agent_config=agent_config, + connector_service=connector_service, + firecrawl_api_key=firecrawl_api_key, + checkpointer=checkpointer, + ) diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/factory.py b/surfsense_backend/app/automations/actions/builtin/agent_task/factory.py new file mode 100644 index 000000000..f4f5d7d37 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/factory.py @@ -0,0 +1,28 @@ +"""Bind ``ActionContext`` to a callable that runs one ``agent_task`` step.""" + +from __future__ import annotations + +from typing import Any + +from ...types import ActionContext, ActionHandler +from .invoke import run_agent_task +from .params import AgentTaskActionParams + + +def build_handler(ctx: ActionContext) -> ActionHandler: + """Return a handler closure that validates params and runs the agent task.""" + + async def handle(params: dict[str, Any]) -> dict[str, Any]: + validated = AgentTaskActionParams.model_validate(params) + return await run_agent_task( + ctx=ctx, + query=validated.query, + auto_approve_all=validated.auto_approve_all, + mentioned_document_ids=validated.mentioned_document_ids, + mentioned_folder_ids=validated.mentioned_folder_ids, + mentioned_connector_ids=validated.mentioned_connector_ids, + mentioned_connectors=validated.mentioned_connectors, + mentioned_documents=validated.mentioned_documents, + ) + + return handle diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/finalize.py b/surfsense_backend/app/automations/actions/builtin/agent_task/finalize.py new file mode 100644 index 000000000..d5f1f95f6 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/finalize.py @@ -0,0 +1,44 @@ +"""Extract the agent's final assistant text from the terminal invoke result.""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.messages import AIMessage 
+ + +def extract_final_assistant_message(result: Any) -> str | None: + """Return the last ``AIMessage`` text content, or ``None`` if there isn't one. + + Multi-part messages (content lists) are flattened by concatenating ``text`` + parts in order. Non-string content (tool calls, images) is skipped. + """ + if not isinstance(result, dict): + return None + messages = result.get("messages") + if not isinstance(messages, list): + return None + + for msg in reversed(messages): + if not isinstance(msg, AIMessage): + continue + return _content_to_text(msg.content) + return None + + +def _content_to_text(content: Any) -> str | None: + if isinstance(content, str): + text = content.strip() + return text or None + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if isinstance(part, str): + parts.append(part) + elif isinstance(part, dict) and part.get("type") == "text": + text = part.get("text") + if isinstance(text, str): + parts.append(text) + joined = "".join(parts).strip() + return joined or None + return None diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py new file mode 100644 index 000000000..99e295f30 --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -0,0 +1,227 @@ +"""Run one ``agent_task`` invocation: ainvoke + auto-decision resume loop.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from langchain_core.messages import HumanMessage +from langgraph.types import Command +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text +from app.db import ChatVisibility, async_session_maker +from app.schemas.new_chat import 
MentionedDocumentInfo + +from ...types import ActionContext +from .auto_decide import build_auto_decisions +from .dependencies import build_dependencies +from .finalize import extract_final_assistant_message + +# Cap on HITL resume iterations. The agent should not need this many turns in one +# step; treat overshoot as a runaway and fail the step. +_MAX_RESUMES = 50 + + +def _build_connector_block(connectors: list[dict[str, Any]]) -> str | None: + """Render the ```` context block (same shape as chat). + + Mirrors ``stream_new_chat`` so the agent gets the exact connector accounts + the user picked. Returns ``None`` when nothing renders. + """ + lines: list[str] = [] + for connector in connectors: + connector_id = connector.get("id") + connector_type = connector.get("connector_type") or connector.get( + "document_type" + ) + account_name = connector.get("account_name") or connector.get("title") + if connector_id is None or connector_type is None: + continue + lines.append( + f' - connector_id={connector_id}, connector_type="{connector_type}", ' + f'account_name="{account_name or ""}"' + ) + if not lines: + return None + return ( + "\n" + "The user selected these exact connector accounts with @. " + "These entries are selection metadata, not retrieved connector content. " + "When a connector-backed tool needs an account, use the matching " + "connector_id from this list if the tool supports connector_id:\n" + + "\n".join(lines) + + "\n" + ) + + +async def _resolve_mention_context( + session: AsyncSession, + *, + search_space_id: int, + query: str, + mentioned_document_ids: list[int] | None, + mentioned_folder_ids: list[int] | None, + mentioned_connector_ids: list[int] | None, + mentioned_connectors: list[MentionedDocumentInfo] | None, + mentioned_documents: list[MentionedDocumentInfo] | None, +) -> tuple[str, SurfSenseContextSchema | None]: + """Resolve @-mentions into a rewritten query + per-invocation context. 
+ + Automation always runs in cloud filesystem mode, so we mirror the chat + ``new_chat`` flow: substitute ``@title`` tokens with canonical + ``/documents/...`` paths, prepend a ```` block, and + build a ``SurfSenseContextSchema`` that ``KnowledgePriorityMiddleware`` + reads via ``runtime.context``. Returns ``(query, None)`` unchanged when + there are no mentions. + """ + has_mentions = bool( + mentioned_document_ids + or mentioned_folder_ids + or mentioned_connector_ids + or mentioned_connectors + or mentioned_documents + ) + if not has_mentions: + return query, None + + resolved = await resolve_mentions( + session, + search_space_id=search_space_id, + mentioned_documents=mentioned_documents, + mentioned_document_ids=mentioned_document_ids, + mentioned_folder_ids=mentioned_folder_ids, + ) + agent_query = substitute_in_text(query, resolved.token_to_path) + + # ``SurfSenseContextSchema.mentioned_connectors`` is typed ``list[dict]`` and + # the connector block reads dicts, so dump the pydantic chips once. 
+ connector_dicts = [c.model_dump() for c in (mentioned_connectors or [])] + connector_block = _build_connector_block(connector_dicts) + if connector_block: + agent_query = f"{connector_block}\n\n{agent_query}" + + runtime_context = SurfSenseContextSchema( + search_space_id=search_space_id, + mentioned_document_ids=list( + resolved.mentioned_document_ids or (mentioned_document_ids or []) + ), + mentioned_folder_ids=list( + resolved.mentioned_folder_ids or (mentioned_folder_ids or []) + ), + mentioned_connector_ids=list(mentioned_connector_ids or []), + mentioned_connectors=connector_dicts, + ) + return agent_query, runtime_context + + +async def run_agent_task( + *, + ctx: ActionContext, + query: str, + auto_approve_all: bool, + mentioned_document_ids: list[int] | None = None, + mentioned_folder_ids: list[int] | None = None, + mentioned_connector_ids: list[int] | None = None, + mentioned_connectors: list[MentionedDocumentInfo] | None = None, + mentioned_documents: list[MentionedDocumentInfo] | None = None, +) -> dict[str, Any]: + """Invoke multi_agent_chat for one rendered query and return its outcome. + + Opens its own DB session so the executor's bookkeeping session isn't tied + up for the entire invocation. The LangGraph ``thread_id`` (a fresh UUID) + is returned as ``agent_session_id`` for later inspection. + + @-mentions (files / folders / connectors) chosen in the task input are + resolved the same way the chat flow does and forwarded to the agent via the + per-invocation ``context`` so they actually scope retrieval. 
+ """ + agent_session_id = str(uuid.uuid4()) + user_id = str(ctx.creator_user_id) if ctx.creator_user_id else None + decision = "approve" if auto_approve_all else "reject" + + async with async_session_maker() as agent_session: + deps = await build_dependencies( + session=agent_session, + search_space_id=ctx.search_space_id, + agent_llm_id=ctx.agent_llm_id, + image_generation_config_id=ctx.image_generation_config_id, + vision_llm_config_id=ctx.vision_llm_config_id, + ) + + agent = await create_multi_agent_chat_deep_agent( + llm=deps.llm, + search_space_id=ctx.search_space_id, + db_session=agent_session, + connector_service=deps.connector_service, + checkpointer=deps.checkpointer, + user_id=user_id, + thread_id=None, + agent_config=deps.agent_config, + firecrawl_api_key=deps.firecrawl_api_key, + thread_visibility=ChatVisibility.PRIVATE, + mentioned_document_ids=mentioned_document_ids, + image_generation_config_id=ctx.image_generation_config_id, + ) + + agent_query, runtime_context = await _resolve_mention_context( + agent_session, + search_space_id=ctx.search_space_id, + query=query, + mentioned_document_ids=mentioned_document_ids, + mentioned_folder_ids=mentioned_folder_ids, + mentioned_connector_ids=mentioned_connector_ids, + mentioned_connectors=mentioned_connectors, + mentioned_documents=mentioned_documents, + ) + + request_id = f"automation:{ctx.run_id}:{ctx.step_id}" + turn_id = f"{request_id}:{int(time.time() * 1000)}" + input_state: dict[str, Any] = { + "messages": [HumanMessage(content=agent_query)], + "search_space_id": ctx.search_space_id, + "request_id": request_id, + "turn_id": turn_id, + } + config: dict[str, Any] = { + "configurable": { + "thread_id": agent_session_id, + "request_id": request_id, + "turn_id": turn_id, + }, + "recursion_limit": 10_000, + } + if runtime_context is not None: + runtime_context.request_id = request_id + runtime_context.turn_id = turn_id + + # The compiled graph declares ``context_schema=SurfSenseContextSchema``; + # 
mentions only reach ``KnowledgePriorityMiddleware`` via ``context=``. + invoke_kwargs: dict[str, Any] = {"config": config} + if runtime_context is not None: + invoke_kwargs["context"] = runtime_context + + result = await agent.ainvoke(input_state, **invoke_kwargs) + + resumes = 0 + while True: + state = await agent.aget_state(config) + if not getattr(state, "interrupts", None): + break + if resumes >= _MAX_RESUMES: + raise RuntimeError( + f"agent_task exceeded {_MAX_RESUMES} HITL resume iterations" + ) + lg_resume_map, routed = build_auto_decisions(state, decision) + config["configurable"]["surfsense_resume_value"] = routed + result = await agent.ainvoke(Command(resume=lg_resume_map), **invoke_kwargs) + resumes += 1 + + return { + "agent_session_id": agent_session_id, + "final_message": extract_final_assistant_message(result), + "resumes": resumes, + } diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/params.py b/surfsense_backend/app/automations/actions/builtin/agent_task/params.py new file mode 100644 index 000000000..ad6f35edb --- /dev/null +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/params.py @@ -0,0 +1,52 @@ +"""``AgentTaskActionParams`` — params for the ``agent_task`` action type.""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + +from app.schemas.new_chat import MentionedDocumentInfo + + +class AgentTaskActionParams(BaseModel): + """Run a multi_agent_chat turn from an automation step.""" + + model_config = ConfigDict(extra="forbid") + + query: str = Field( + ..., + min_length=1, + description="User query for the agent; rendered at execute time.", + ) + auto_approve_all: bool = Field( + default=False, + description="If true, every HITL approval is auto-approved; otherwise rejected.", + ) + + # @-mention references chosen in the task input. 
Mirror the ``new_chat`` + # request fields (minus SurfSense product docs) so the run can scope + # retrieval to the user's selected files / folders / connectors. All + # optional and additive; a task with no mentions behaves as before. + mentioned_document_ids: list[int] | None = Field( + default=None, + description="Knowledge-base document IDs the task references with @.", + ) + mentioned_folder_ids: list[int] | None = Field( + default=None, + description="Knowledge-base folder IDs the task references with @.", + ) + mentioned_connector_ids: list[int] | None = Field( + default=None, + description="Concrete connector account IDs the task references with @.", + ) + mentioned_connectors: list[MentionedDocumentInfo] | None = Field( + default=None, + description="Display/context metadata for the @-mentioned connector accounts.", + ) + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Chip metadata (id, title, kind, ...) for every @-mention so the " + "run can resolve titles to virtual paths and substitute them in " + "the query." + ), + ) diff --git a/surfsense_backend/app/automations/actions/store.py b/surfsense_backend/app/automations/actions/store.py new file mode 100644 index 000000000..eff66c4c7 --- /dev/null +++ b/surfsense_backend/app/automations/actions/store.py @@ -0,0 +1,23 @@ +"""In-memory action registry. Populated once at process startup.""" + +from __future__ import annotations + +from .types import ActionDefinition + +_REGISTRY: dict[str, ActionDefinition] = {} + + +def register_action(action: ActionDefinition) -> None: + """Register an action. 
Raises on duplicate type.""" + if action.type in _REGISTRY: + raise ValueError(f"Action already registered: {action.type!r}") + _REGISTRY[action.type] = action + + +def get_action(action_type: str) -> ActionDefinition | None: + return _REGISTRY.get(action_type) + + +def all_actions() -> dict[str, ActionDefinition]: + """Defensive snapshot of the registry.""" + return dict(_REGISTRY) diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py new file mode 100644 index 000000000..453721a43 --- /dev/null +++ b/surfsense_backend/app/automations/actions/types.py @@ -0,0 +1,46 @@ +"""``ActionDefinition``, ``ActionContext``, and handler/factory signatures.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + + +@dataclass(frozen=True, slots=True) +class ActionContext: + """Per-invocation dependencies bound to an action handler at execute time.""" + + session: AsyncSession + run_id: int + step_id: str + search_space_id: int + creator_user_id: UUID | None + # Captured model snapshot from the automation definition (``definition.models``), + # resolved per run instead of the live search space. ``None`` falls back to the + # search space's current prefs (defensive; should not happen post-capture). 
+ agent_llm_id: int | None = None + image_generation_config_id: int | None = None + vision_llm_config_id: int | None = None + + +ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] +ActionHandlerFactory = Callable[[ActionContext], ActionHandler] + + +@dataclass(frozen=True, slots=True) +class ActionDefinition: + type: str + name: str + description: str + params_model: type[BaseModel] + build_handler: ActionHandlerFactory + + @property + def params_schema(self) -> dict[str, Any]: + """JSON Schema (draft 2020-12) derived from ``params_model``.""" + return self.params_model.model_json_schema() diff --git a/surfsense_backend/app/automations/api/__init__.py b/surfsense_backend/app/automations/api/__init__.py new file mode 100644 index 000000000..a18e91a95 --- /dev/null +++ b/surfsense_backend/app/automations/api/__init__.py @@ -0,0 +1,16 @@ +"""HTTP layer for the automations feature.""" + +from __future__ import annotations + +from fastapi import APIRouter + +from .automation import router as automation_router +from .run import router as run_router +from .trigger import router as trigger_router + +router = APIRouter() +router.include_router(automation_router) +router.include_router(trigger_router) +router.include_router(run_router) + +__all__ = ["router"] diff --git a/surfsense_backend/app/automations/api/automation.py b/surfsense_backend/app/automations/api/automation.py new file mode 100644 index 000000000..911ae57a6 --- /dev/null +++ b/surfsense_backend/app/automations/api/automation.py @@ -0,0 +1,109 @@ +"""HTTP routes for the ``Automation`` resource.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, status +from pydantic import BaseModel + +from app.automations.schemas.api import ( + AutomationCreate, + AutomationDetail, + AutomationList, + AutomationSummary, + AutomationUpdate, +) +from app.automations.services import AutomationService, get_automation_service + +router = APIRouter() + + +class 
ModelEligibilityViolation(BaseModel): + kind: str + config_id: int | None + reason: str + + +class ModelEligibility(BaseModel): + allowed: bool + violations: list[ModelEligibilityViolation] + + +@router.post( + "/automations", + response_model=AutomationDetail, + status_code=status.HTTP_201_CREATED, +) +async def create_automation( + payload: AutomationCreate, + service: AutomationService = Depends(get_automation_service), +) -> AutomationDetail: + """Create an automation, optionally with initial triggers (atomic).""" + automation = await service.create(payload) + return AutomationDetail.model_validate(automation) + + +@router.get("/automations", response_model=AutomationList) +async def list_automations( + search_space_id: int = Query(...), + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), + service: AutomationService = Depends(get_automation_service), +) -> AutomationList: + """List automations in a search space.""" + items, total = await service.list( + search_space_id=search_space_id, limit=limit, offset=offset + ) + return AutomationList( + items=[AutomationSummary.model_validate(a) for a in items], + total=total, + ) + + +@router.get("/automations/model-eligibility", response_model=ModelEligibility) +async def get_automation_model_eligibility( + search_space_id: int = Query(...), + service: AutomationService = Depends(get_automation_service), +) -> ModelEligibility: + """Report whether a search space's models are billable for automations. + + Used by the frontend to gate creation: automations may only use premium + global models or user BYOK models (free models and Auto mode are blocked). + + NOTE: declared before ``/automations/{automation_id}`` so the literal path + isn't captured by the int-typed ``{automation_id}`` route. 
+ """ + result = await service.model_eligibility(search_space_id=search_space_id) + return ModelEligibility.model_validate(result) + + +@router.get("/automations/{automation_id}", response_model=AutomationDetail) +async def get_automation( + automation_id: int, + service: AutomationService = Depends(get_automation_service), +) -> AutomationDetail: + """Get one automation with its definition and triggers.""" + automation = await service.get(automation_id) + return AutomationDetail.model_validate(automation) + + +@router.patch("/automations/{automation_id}", response_model=AutomationDetail) +async def update_automation( + automation_id: int, + patch: AutomationUpdate, + service: AutomationService = Depends(get_automation_service), +) -> AutomationDetail: + """Partially update an automation. Triggers are managed separately.""" + automation = await service.update(automation_id, patch) + return AutomationDetail.model_validate(automation) + + +@router.delete( + "/automations/{automation_id}", + status_code=status.HTTP_204_NO_CONTENT, +) +async def delete_automation( + automation_id: int, + service: AutomationService = Depends(get_automation_service), +) -> None: + """Delete an automation; triggers and runs are removed by FK cascade.""" + await service.delete(automation_id) diff --git a/surfsense_backend/app/automations/api/run.py b/surfsense_backend/app/automations/api/run.py new file mode 100644 index 000000000..b662a5943 --- /dev/null +++ b/surfsense_backend/app/automations/api/run.py @@ -0,0 +1,44 @@ +"""HTTP routes for automation run history.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query + +from app.automations.schemas.api import RunDetail, RunList, RunSummary +from app.automations.services import RunService, get_run_service + +router = APIRouter() + + +@router.get( + "/automations/{automation_id}/runs", + response_model=RunList, +) +async def list_runs( + automation_id: int, + limit: int = Query(default=50, ge=1, le=200), 
+ offset: int = Query(default=0, ge=0), + service: RunService = Depends(get_run_service), +) -> RunList: + """List run history for an automation, newest first.""" + items, total = await service.list( + automation_id=automation_id, limit=limit, offset=offset + ) + return RunList( + items=[RunSummary.model_validate(r) for r in items], + total=total, + ) + + +@router.get( + "/automations/{automation_id}/runs/{run_id}", + response_model=RunDetail, +) +async def get_run( + automation_id: int, + run_id: int, + service: RunService = Depends(get_run_service), +) -> RunDetail: + """Get the full record of a single run, including step results and artifacts.""" + run = await service.get(automation_id=automation_id, run_id=run_id) + return RunDetail.model_validate(run) diff --git a/surfsense_backend/app/automations/api/trigger.py b/surfsense_backend/app/automations/api/trigger.py new file mode 100644 index 000000000..40e47a86b --- /dev/null +++ b/surfsense_backend/app/automations/api/trigger.py @@ -0,0 +1,55 @@ +"""HTTP routes for triggers attached to an automation.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, status + +from app.automations.schemas.api import TriggerCreate, TriggerDetail, TriggerUpdate +from app.automations.services import TriggerService, get_trigger_service + +router = APIRouter() + + +@router.post( + "/automations/{automation_id}/triggers", + response_model=TriggerDetail, + status_code=status.HTTP_201_CREATED, +) +async def add_trigger( + automation_id: int, + payload: TriggerCreate, + service: TriggerService = Depends(get_trigger_service), +) -> TriggerDetail: + """Attach a new trigger to an automation.""" + trigger = await service.add(automation_id=automation_id, payload=payload) + return TriggerDetail.model_validate(trigger) + + +@router.patch( + "/automations/{automation_id}/triggers/{trigger_id}", + response_model=TriggerDetail, +) +async def update_trigger( + automation_id: int, + trigger_id: int, + patch: 
TriggerUpdate, + service: TriggerService = Depends(get_trigger_service), +) -> TriggerDetail: + """Toggle ``enabled`` or replace ``params``. Trigger type is immutable.""" + trigger = await service.update( + automation_id=automation_id, trigger_id=trigger_id, patch=patch + ) + return TriggerDetail.model_validate(trigger) + + +@router.delete( + "/automations/{automation_id}/triggers/{trigger_id}", + status_code=status.HTTP_204_NO_CONTENT, +) +async def remove_trigger( + automation_id: int, + trigger_id: int, + service: TriggerService = Depends(get_trigger_service), +) -> None: + """Detach a trigger from an automation.""" + await service.remove(automation_id=automation_id, trigger_id=trigger_id) diff --git a/surfsense_backend/app/automations/dispatch/__init__.py b/surfsense_backend/app/automations/dispatch/__init__.py new file mode 100644 index 000000000..bab1d122e --- /dev/null +++ b/surfsense_backend/app/automations/dispatch/__init__.py @@ -0,0 +1,8 @@ +"""Generic dispatch primitives shared across trigger types.""" + +from __future__ import annotations + +from .errors import DispatchError +from .launch import launch_run + +__all__ = ["DispatchError", "launch_run"] diff --git a/surfsense_backend/app/automations/dispatch/errors.py b/surfsense_backend/app/automations/dispatch/errors.py new file mode 100644 index 000000000..75640a987 --- /dev/null +++ b/surfsense_backend/app/automations/dispatch/errors.py @@ -0,0 +1,7 @@ +"""Dispatch errors raised when a fire request cannot be turned into a run.""" + +from __future__ import annotations + + +class DispatchError(Exception): + """A dispatch could not proceed (missing trigger, invalid inputs, ...).""" diff --git a/surfsense_backend/app/automations/dispatch/inputs.py b/surfsense_backend/app/automations/dispatch/inputs.py new file mode 100644 index 000000000..61546b993 --- /dev/null +++ b/surfsense_backend/app/automations/dispatch/inputs.py @@ -0,0 +1,43 @@ +"""Merge and validate the inputs a run starts with.""" + +from 
__future__ import annotations + +from typing import Any + +import jsonschema + +from app.automations.persistence.models.trigger import AutomationTrigger +from app.automations.schemas.definition.envelope import AutomationDefinition + +from .errors import DispatchError + + +def prepare_inputs( + definition: AutomationDefinition, + trigger: AutomationTrigger, + runtime_inputs: dict[str, Any] | None, +) -> dict[str, Any]: + """Merge ``trigger.static_inputs`` over ``runtime_inputs``, then validate. + + Static inputs win on key collision. + """ + merged = {**(runtime_inputs or {}), **(trigger.static_inputs or {})} + return validate_inputs(definition, merged) + + +def validate_inputs( + definition: AutomationDefinition, inputs: dict[str, Any] +) -> dict[str, Any]: + """Validate ``inputs`` against the definition's optional declared schema. + + No declared schema → pass through unchanged so runtime keys (``fired_at``, + ``last_fired_at``, ...) still reach the template context. A declared schema + that the inputs violate is surfaced as ``DispatchError``. + """ + if definition.inputs is None or not definition.inputs.schema_: + return inputs + try: + jsonschema.validate(instance=inputs, schema=definition.inputs.schema_) + except jsonschema.ValidationError as exc: + raise DispatchError(f"inputs: {exc.message}") from exc + return inputs diff --git a/surfsense_backend/app/automations/dispatch/launch.py b/surfsense_backend/app/automations/dispatch/launch.py new file mode 100644 index 000000000..cf7fb53d8 --- /dev/null +++ b/surfsense_backend/app/automations/dispatch/launch.py @@ -0,0 +1,60 @@ +"""Launch a run for a trigger that fired: resolve, validate, persist, enqueue. + +The trigger-facing entry every selector calls. A selector builds the runtime +inputs and hands one trigger row here; this resolves and guards its automation, +snapshots the definition onto a PENDING run, and enqueues execution. The +snapshot makes the run immune to later edits of the automation. 
+""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.persistence.enums.run_status import RunStatus +from app.automations.persistence.models.run import AutomationRun +from app.automations.persistence.models.trigger import AutomationTrigger +from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.tasks.execute_run import automation_run_execute + +from .errors import DispatchError +from .inputs import prepare_inputs +from .resolve import resolve_active_automation + + +async def launch_run( + *, + session: AsyncSession, + trigger: AutomationTrigger, + runtime_inputs: dict[str, Any] | None = None, +) -> AutomationRun: + """Resolve ``trigger``'s active automation and enqueue a PENDING run for it.""" + automation = await resolve_active_automation(session, trigger) + + try: + definition = AutomationDefinition.model_validate(automation.definition) + except Exception as exc: + raise DispatchError(f"invalid automation definition: {exc}") from exc + + inputs = prepare_inputs(definition, trigger, runtime_inputs) + snapshot = definition.model_dump(mode="json", by_alias=True) + + run = AutomationRun( + automation_id=automation.id, + trigger_id=trigger.id, + status=RunStatus.PENDING, + definition_snapshot=snapshot, + inputs=inputs, + step_results=[], + artifacts=[], + ) + session.add(run) + await session.commit() + await session.refresh(run) + + automation_run_execute.apply_async( + args=[run.id], + time_limit=definition.execution.timeout_seconds, + ) + return run diff --git a/surfsense_backend/app/automations/dispatch/resolve.py b/surfsense_backend/app/automations/dispatch/resolve.py new file mode 100644 index 000000000..13efd15ee --- /dev/null +++ b/surfsense_backend/app/automations/dispatch/resolve.py @@ -0,0 +1,40 @@ +"""Resolve the automation behind a trigger and guard that it may run.""" + +from __future__ import annotations + +from sqlalchemy 
async def resolve_active_automation(
    session: AsyncSession, trigger: AutomationTrigger
) -> Automation:
    """Return the automation behind ``trigger``, provided it is ``ACTIVE``.

    Raises:
        DispatchError: When no automation row matches
            ``trigger.automation_id``, or when the row exists but its status
            is anything other than ``ACTIVE``.
    """
    automation = await _load_automation(session, trigger.automation_id)

    # Happy path first: found and eligible to fire.
    if automation is not None and automation.status == AutomationStatus.ACTIVE:
        return automation

    if automation is None:
        raise DispatchError(
            f"automation {trigger.automation_id} not found for trigger {trigger.id}"
        )

    raise DispatchError(
        f"automation {trigger.automation_id} is {automation.status.value}, not active"
    )


async def _load_automation(
    session: AsyncSession, automation_id: int
) -> Automation | None:
    """Fetch one automation by primary key, or ``None`` when it does not exist."""
    query = select(Automation).where(Automation.id == automation_id)
    rows = await session.execute(query)
    return rows.scalar_one_or_none()
class RunStatus(StrEnum):
    """Lifecycle states of an ``AutomationRun``.

    The intended flow is ``pending`` -> ``running`` -> one of the four
    terminal states below. Members are strings, so they compare equal to
    their lowercase values.
    """

    # Initial state of a freshly created run.
    PENDING = "pending"
    # The executor has started walking the plan.
    RUNNING = "running"
    # Terminal states.
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
    TIMED_OUT = "timed_out"
class TriggerType(StrEnum):
    """Discriminator for the kind of trigger attached to an automation.

    ``SCHEDULE`` and ``EVENT`` are the registered kinds. ``MANUAL`` exists
    only so the enum mirrors the Postgres enum; it is intentionally left
    unregistered pending a redesign of the "Run now" UX.
    """

    SCHEDULE = "schedule"
    EVENT = "event"
    # Reserved: present in the DB enum but not registered as a trigger kind.
    MANUAL = "manual"
class AutomationRun(BaseModel, TimestampMixin):
    """ORM mapping for ``automation_runs``: one row per fire of an automation.

    A row carries a snapshot of the definition it ran with, the inputs it was
    dispatched with, and the accumulated per-step results, output, artifacts
    and error of that execution.
    """

    __tablename__ = "automation_runs"

    # Owning automation. The FK cascades on delete, so runs go away with it.
    automation_id = Column(
        Integer,
        ForeignKey("automations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Trigger that fired this run. Nullable with SET NULL so the run record
    # survives deletion of its trigger.
    trigger_id = Column(
        Integer,
        ForeignKey("automation_triggers.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    # Stored as the enum's lowercase *values* (values_callable), defaulting to
    # PENDING both in Python and at the database level.
    status = Column(
        SQLAlchemyEnum(
            RunStatus,
            name="automation_run_status",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        default=RunStatus.PENDING,
        server_default=RunStatus.PENDING.value,
        index=True,
    )

    # locked at fire time so historical runs always show the exact code path
    definition_snapshot = Column(JSONB, nullable=False)

    # merged & validated inputs the run was dispatched with
    # (trigger.static_inputs union producer runtime data, static wins on collision)
    inputs = Column(JSONB, nullable=False, server_default="{}")
    # one entry per executed step; agent_task entries carry their own
    # `agent_session_id` inside their entry
    step_results = Column(JSONB, nullable=False, server_default="[]")
    # Final run output, if any; NULL until something sets it.
    output = Column(JSONB, nullable=True)
    # JSON list of artifacts; defaults to an empty list.
    artifacts = Column(JSONB, nullable=False, server_default="[]")
    # Structured error payload for a failed run; NULL otherwise.
    error = Column(JSONB, nullable=True)

    # Timezone-aware timestamps; both are NULL until the run starts / ends.
    started_at = Column(TIMESTAMP(timezone=True), nullable=True)
    finished_at = Column(TIMESTAMP(timezone=True), nullable=True)

    automation = relationship("Automation", back_populates="runs")
    trigger = relationship("AutomationTrigger", back_populates="runs")
class AutomationTrigger(BaseModel, TimestampMixin):
    """ORM mapping for ``automation_triggers``.

    One row per (automation, trigger-instance) pair: the trigger kind, its
    kind-specific params, the static inputs merged into each run it
    dispatches, and its firing bookkeeping.
    """

    __tablename__ = "automation_triggers"

    # Owning automation. The FK cascades on delete.
    automation_id = Column(
        Integer,
        ForeignKey("automations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Trigger kind, stored as the enum's lowercase *values* (values_callable).
    type = Column(
        SQLAlchemyEnum(
            TriggerType,
            name="automation_trigger_type",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        index=True,
    )

    # Kind-specific configuration. NOTE(review): its shape is presumably
    # validated by the per-trigger params schema, not here; confirm.
    params = Column(JSONB, nullable=False)

    # Per-attachment domain values merged into every dispatched run's inputs.
    # Static wins over runtime data on key collision.
    static_inputs = Column(JSONB, nullable=False, server_default="{}")

    # Per-trigger on/off switch; defaults to enabled in Python and in the DB.
    enabled = Column(
        Boolean,
        nullable=False,
        default=True,
        server_default="true",
        index=True,
    )

    # NULL until a value is written; nothing in this model sets it.
    last_fired_at = Column(TIMESTAMP(timezone=True), nullable=True)

    # Precomputed next fire moment in UTC; advanced after each fire by the
    # schedule tick. NULL means the trigger has never been scheduled (the
    # tick self-heals on first sight).
    next_fire_at = Column(TIMESTAMP(timezone=True), nullable=True)

    automation = relationship("Automation", back_populates="triggers")
    # No delete cascade here: runs keep existing when a trigger is removed
    # (their FK is SET NULL), and passive_deletes leaves that to the database.
    runs = relationship(
        "AutomationRun",
        back_populates="trigger",
        passive_deletes=True,
    )
async def execute_run(session: AsyncSession, run_id: int) -> None:
    """Load run ``run_id`` and execute its snapshot plan to a terminal state.

    The plan is read from the run's ``definition_snapshot``, not from the
    automation's current definition. Steps run sequentially; each step's
    result is appended to the run and committed before the next step starts.

    Args:
        session: Async DB session; committed repeatedly as the run progresses.
        run_id: Primary key of the ``AutomationRun`` to execute.

    Raises:
        ValueError: If no run with ``run_id`` exists.
    """
    run = await repository.load_run(session, run_id)
    if run is None:
        raise ValueError(f"automation_run {run_id} not found")

    # Only PENDING runs are executed; any other status is left untouched.
    if run.status != RunStatus.PENDING:
        return

    try:
        definition = AutomationDefinition.model_validate(run.definition_snapshot)
    except Exception as exc:
        # An unusable snapshot fails the run without it ever entering RUNNING.
        await repository.mark_failed(
            session,
            run,
            {
                "message": f"definition_snapshot invalid: {exc}",
                "type": type(exc).__name__,
            },
        )
        await session.commit()
        return

    await repository.mark_running(session, run)
    await session.commit()

    # Outputs of succeeded steps, keyed by ``output_as`` (or ``step_id``),
    # fed into the template context of later steps.
    step_outputs: dict[str, Any] = {}

    for step in definition.plan:
        template_ctx = _build_template_ctx(run, step_outputs)
        action_ctx = _build_action_ctx(session, run, step, definition.models)
        result = await execute_step(
            step=step,
            template_context=template_ctx,
            action_context=action_ctx,
            default_max_retries=definition.execution.max_retries,
            default_retry_backoff=definition.execution.retry_backoff,
            default_timeout_seconds=definition.execution.timeout_seconds,
        )
        # Persist each step result immediately so progress is durable.
        await repository.append_step_result(session, run, result)
        await session.commit()

        if result["status"] == "failed":
            # First failure aborts the plan: run the on_failure steps, then
            # record the failing step's error on the run.
            await _run_on_failure(session, run, definition)
            await repository.mark_failed(session, run, result.get("error"))
            await session.commit()
            return

        # Any other non-succeeded status (e.g. a skipped step) binds no output.
        if result["status"] == "succeeded":
            step_outputs[step.output_as or step.step_id] = result.get("result")

    await repository.mark_succeeded(session, run)
    await session.commit()
def _build_template_ctx(
    run: AutomationRun, step_outputs: dict[str, Any]
) -> dict[str, Any]:
    """Assemble the templating context for ``run`` via ``build_run_context``.

    Fields sourced from the related automation or trigger fall back to
    ``None`` when that relationship is empty. ``attempt`` is always ``1``.
    """
    owner = run.automation
    fired_by = run.trigger

    if owner:
        name = owner.name
        version = owner.version
        space_id = owner.search_space_id
        creator = owner.created_by_user_id
    else:
        name = version = space_id = creator = None

    kind = fired_by.type.value if fired_by else None

    return build_run_context(
        run_id=run.id,
        automation_id=run.automation_id,
        automation_name=name,
        automation_version=version,
        search_space_id=space_id,
        creator_id=creator,
        trigger_id=run.trigger_id,
        trigger_type=kind,
        started_at=run.started_at,
        attempt=1,
        inputs=run.inputs or {},
        step_outputs=step_outputs,
    )
async def load_run(session: AsyncSession, run_id: int) -> AutomationRun | None:
    """Fetch run ``run_id`` with its automation and trigger eagerly loaded.

    Returns ``None`` when no such run exists.
    """
    eager_relations = (
        selectinload(AutomationRun.automation),
        selectinload(AutomationRun.trigger),
    )
    query = (
        select(AutomationRun)
        .where(AutomationRun.id == run_id)
        .options(*eager_relations)
    )
    return (await session.execute(query)).scalar_one_or_none()
Reassigns the list so SQLAlchemy detects the change.""" + current = list(run.step_results or []) + current.append(step_result) + run.step_results = current + await session.flush() diff --git a/surfsense_backend/app/automations/runtime/retries.py b/surfsense_backend/app/automations/runtime/retries.py new file mode 100644 index 000000000..d5bfb15ca --- /dev/null +++ b/surfsense_backend/app/automations/runtime/retries.py @@ -0,0 +1,36 @@ +"""Retry policy enforcement for action handlers.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable + + +async def with_retries[T]( + coro_factory: Callable[[], Awaitable[T]], + *, + max_retries: int, + backoff: str, + timeout: int | None, +) -> tuple[T, int]: + """Call ``coro_factory`` up to ``1 + max_retries`` times. Return ``(result, attempts)``.""" + total = 1 + max(0, max_retries) + for attempt in range(1, total + 1): + try: + coro = coro_factory() + if timeout is not None and timeout > 0: + return await asyncio.wait_for(coro, timeout=timeout), attempt + return await coro, attempt + except Exception: + if attempt >= total: + raise + await asyncio.sleep(_backoff_seconds(backoff, attempt)) + raise RuntimeError("with_retries exhausted without raising or returning") + + +def _backoff_seconds(strategy: str, attempt: int) -> float: + if strategy == "exponential": + return float(2 ** (attempt - 1)) + if strategy == "linear": + return float(attempt) + return 0.0 diff --git a/surfsense_backend/app/automations/runtime/step.py b/surfsense_backend/app/automations/runtime/step.py new file mode 100644 index 000000000..6e7c9c671 --- /dev/null +++ b/surfsense_backend/app/automations/runtime/step.py @@ -0,0 +1,107 @@ +"""Execute one plan step: when-predicate, params render, handler dispatch, retries.""" + +from __future__ import annotations + +from collections.abc import Mapping +from datetime import UTC, datetime +from typing import Any + +from app.automations.actions import get_action 
async def execute_step(
    *,
    step: PlanStep,
    template_context: Mapping[str, Any],
    action_context: ActionContext,
    default_max_retries: int,
    default_retry_backoff: str,
    default_timeout_seconds: int,
) -> dict[str, Any]:
    """Run one step and return its structured result entry.

    Phases, each of which turns its own failure into a ``failed`` entry
    instead of raising:

    1. ``when`` predicate: a falsy result yields a ``skipped`` entry.
    2. Render ``step.params`` against ``template_context``.
    3. Look up the action in the registry and build its handler.
    4. Invoke the handler under the retry/timeout policy.

    Args:
        step: The plan step to execute.
        template_context: Values available to the predicate and params.
        action_context: Per-step context handed to the action's handler factory.
        default_max_retries: Retry budget used when the step sets none.
        default_retry_backoff: Backoff strategy name passed to ``with_retries``.
        default_timeout_seconds: Per-attempt timeout used when the step's own
            timeout is unset (or falsy).

    Returns:
        A result dict as built by ``_result`` with status ``succeeded``,
        ``failed`` or ``skipped``.
    """
    started_at = datetime.now(UTC)

    if step.when is not None:
        try:
            should_run = evaluate_predicate(step.when, template_context)
        except Exception as exc:
            return _result(
                step, "failed", started_at, attempts=0, error=_error(exc, "when")
            )
        if not should_run:
            return _result(step, "skipped", started_at, attempts=0)

    try:
        resolved_params = render_value(step.params, template_context)
    except Exception as exc:
        return _result(
            step, "failed", started_at, attempts=0, error=_error(exc, "render")
        )

    action = get_action(step.action)
    if action is None:
        return _result(
            step,
            "failed",
            started_at,
            attempts=0,
            error={
                "message": f"action not registered: {step.action}",
                "type": "ActionNotFound",
            },
        )

    # Building the handler is part of the step: report a failure here the same
    # way as every other phase rather than letting it escape as an exception.
    try:
        handler = action.build_handler(action_context)
    except Exception as exc:
        return _result(
            step,
            "failed",
            started_at,
            attempts=0,
            error=_error(exc, "build_handler"),
        )

    max_retries = (
        step.max_retries if step.max_retries is not None else default_max_retries
    )
    timeout = step.timeout_seconds or default_timeout_seconds

    try:
        result, attempts = await with_retries(
            lambda: handler(resolved_params),
            max_retries=max_retries,
            backoff=default_retry_backoff,
            timeout=timeout,
        )
    except Exception as exc:
        # with_retries only raises after using its whole budget, which it
        # clamps to 1 + max(0, max_retries); report that same number.
        return _result(
            step,
            "failed",
            started_at,
            attempts=1 + max(0, max_retries),
            error=_error(exc),
        )

    return _result(step, "succeeded", started_at, attempts=attempts, result=result)
dict[str, Any] | None = None, +) -> dict[str, Any]: + entry: dict[str, Any] = { + "step_id": step.step_id, + "action": step.action, + "status": status, + "started_at": started_at.isoformat(), + "finished_at": datetime.now(UTC).isoformat(), + "attempts": attempts, + } + if result is not None: + entry["result"] = result + if error is not None: + entry["error"] = error + return entry + + +def _error(exc: Exception, phase: str | None = None) -> dict[str, Any]: + msg = f"{phase}: {exc}" if phase else str(exc) + return {"message": msg, "type": type(exc).__name__} diff --git a/surfsense_backend/app/automations/schemas/__init__.py b/surfsense_backend/app/automations/schemas/__init__.py new file mode 100644 index 000000000..2e2d60f12 --- /dev/null +++ b/surfsense_backend/app/automations/schemas/__init__.py @@ -0,0 +1,27 @@ +"""Schemas for the automation definition envelope. + +Per-action and per-trigger params schemas live with the action/trigger +implementations (``app.automations.actions..params`` / +``app.automations.triggers..params``); only the cross-cutting envelope +lives here. 
+""" + +from __future__ import annotations + +from .definition import ( + AutomationDefinition, + Execution, + Inputs, + Metadata, + PlanStep, + TriggerSpec, +) + +__all__ = [ + "AutomationDefinition", + "Execution", + "Inputs", + "Metadata", + "PlanStep", + "TriggerSpec", +] diff --git a/surfsense_backend/app/automations/schemas/api/__init__.py b/surfsense_backend/app/automations/schemas/api/__init__.py new file mode 100644 index 000000000..f49e5c589 --- /dev/null +++ b/surfsense_backend/app/automations/schemas/api/__init__.py @@ -0,0 +1,27 @@ +"""Request/response schemas for the automations HTTP layer.""" + +from __future__ import annotations + +from .automation import ( + AutomationCreate, + AutomationDetail, + AutomationList, + AutomationSummary, + AutomationUpdate, +) +from .run import RunDetail, RunList, RunSummary +from .trigger import TriggerCreate, TriggerDetail, TriggerUpdate + +__all__ = [ + "AutomationCreate", + "AutomationDetail", + "AutomationList", + "AutomationSummary", + "AutomationUpdate", + "RunDetail", + "RunList", + "RunSummary", + "TriggerCreate", + "TriggerDetail", + "TriggerUpdate", +] diff --git a/surfsense_backend/app/automations/schemas/api/automation.py b/surfsense_backend/app/automations/schemas/api/automation.py new file mode 100644 index 000000000..c1defd417 --- /dev/null +++ b/surfsense_backend/app/automations/schemas/api/automation.py @@ -0,0 +1,64 @@ +"""Request/response schemas for the ``Automation`` resource.""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + +from app.automations.persistence.enums.automation_status import AutomationStatus +from app.automations.schemas.definition import AutomationDefinition + +from .trigger import TriggerCreate, TriggerDetail + + +class AutomationCreate(BaseModel): + """Create an automation, optionally with initial triggers (atomic).""" + + model_config = ConfigDict(extra="forbid") + + search_space_id: int + name: str = 
class AutomationUpdate(BaseModel):
    """Partial update of an automation. Triggers are managed separately.

    Every field is optional and defaults to ``None``; unknown keys are
    rejected.
    """

    # Reject payload keys that are not declared below.
    model_config = ConfigDict(extra="forbid")

    # Same length bounds as on create (1..200 characters).
    name: str | None = Field(default=None, min_length=1, max_length=200)
    description: str | None = None
    status: AutomationStatus | None = None
    # When present, the full definition envelope is validated.
    definition: AutomationDefinition | None = None
class TriggerDetail(BaseModel):
    """Trigger as returned to clients."""

    # Allow building this model straight from an object's attributes
    # (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    type: TriggerType
    params: dict[str, Any]
    static_inputs: dict[str, Any]
    enabled: bool
    # Optional timestamps; ``None`` when no value is set.
    last_fired_at: datetime | None = None
    next_fire_at: datetime | None = None
    created_at: datetime
b/surfsense_backend/app/automations/schemas/definition/__init__.py new file mode 100644 index 000000000..72404264e --- /dev/null +++ b/surfsense_backend/app/automations/schemas/definition/__init__.py @@ -0,0 +1,20 @@ +"""Automation definition envelope and its components.""" + +from __future__ import annotations + +from .envelope import AutomationDefinition, AutomationModels +from .execution import Execution +from .inputs import Inputs +from .metadata import Metadata +from .plan_step import PlanStep +from .trigger_spec import TriggerSpec + +__all__ = [ + "AutomationDefinition", + "AutomationModels", + "Execution", + "Inputs", + "Metadata", + "PlanStep", + "TriggerSpec", +] diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py new file mode 100644 index 000000000..7ca55b1ce --- /dev/null +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -0,0 +1,45 @@ +"""``AutomationDefinition`` — top-level envelope persisted in ``automations.definition``.""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + +from .execution import Execution +from .inputs import Inputs +from .metadata import Metadata +from .plan_step import PlanStep +from .trigger_spec import TriggerSpec + + +class AutomationModels(BaseModel): + """Captured model profile for an automation. + + Snapshotted from the search space's preferences at create time so runs are + insulated from later chat/search-space model changes. Config-id conventions + match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). 
class AutomationDefinition(BaseModel):
    """Top-level shape of an automation.

    Only ``name`` and a non-empty ``plan`` are required; every other section
    has a default. Unknown top-level keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    schema_version: str = "1.0"
    name: str = Field(..., min_length=1, max_length=200)
    goal: str | None = None
    # JSON Schema describing the inputs accepted at fire time, if declared.
    inputs: Inputs | None = None
    triggers: list[TriggerSpec] = Field(default_factory=list)
    # The plan must contain at least one step.
    plan: list[PlanStep] = Field(..., min_length=1)
    # Automation-wide execution defaults (timeout, retries, on_failure, ...).
    execution: Execution = Field(default_factory=Execution)
    metadata: Metadata = Field(default_factory=Metadata)
    # Captured server-side at create() and preserved across update(); resolved
    # at runtime instead of the live search space. Optional so drafts/builder
    # payloads validate without it.
    models: AutomationModels | None = None
# Wrapper around the JSON Schema that describes the inputs an automation
# accepts at fire time.
class Inputs(BaseModel):
    # extra="forbid": unknown keys fail validation.
    # populate_by_name=True: the field can be supplied as either ``schema``
    # (its alias) or ``schema_`` (its attribute name).
    # serialize_by_alias=True: dumps emit the key as ``schema``.
    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        serialize_by_alias=True,
    )

    # Required. The attribute is spelled ``schema_`` while the wire key stays
    # ``schema`` -- presumably to avoid clashing with pydantic's own
    # ``BaseModel.schema`` attribute (TODO confirm).
    schema_: dict[str, Any] = Field(
        ...,
        alias="schema",
        description="JSON Schema (draft 2020-12) for accepted inputs.",
    )
# Free-form metadata attached to a definition.
class Metadata(BaseModel):
    # Unlike the other definition components (which use extra="forbid"),
    # arbitrary additional keys are accepted and kept on the model.
    model_config = ConfigDict(extra="allow")

    # Optional list of string tags; a fresh empty list per instance.
    tags: list[str] = Field(default_factory=list)
# One entry of the definition's ``triggers[]`` array: a trigger type name plus
# its type-specific parameters.
class TriggerSpec(BaseModel):
    # Unknown keys fail validation.
    model_config = ConfigDict(extra="forbid")

    # Required, non-empty trigger type name. This model only checks that it is
    # a non-empty string; resolving it against the registry happens elsewhere.
    type: str = Field(
        ..., min_length=1, description="Trigger type; resolved via registry."
    )
    # Raw parameter mapping. No per-type validation is performed here; per the
    # field description that is done against the trigger's own schema.
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Type-specific params; validated against the trigger's schema.",
    )
class AutomationService:
    """Lifecycle of the ``Automation`` resource.

    Every public method authorizes the caller against the automation's search
    space via ``check_permission`` and commits its own transaction.
    """

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        """Bind the service to a DB session and the acting user."""
        self.session = session
        self.user = user

    async def create(self, payload: AutomationCreate) -> Automation:
        """Create an automation and its initial triggers in one transaction.

        Raises ``HTTPException`` 422 when the selected (or inherited) models are
        not billable or a trigger spec is invalid, and 404 when the search
        space does not exist.
        """
        await self._authorize(
            payload.search_space_id, Permission.AUTOMATIONS_CREATE.value
        )

        # Capture the model profile onto the definition so runs are insulated
        # from later chat/search-space model changes. Two sources:
        #   1. Explicit per-automation selection in ``payload.definition.models``
        #      (manual builder + chat approval card). Validate the chosen ids.
        #   2. Fallback (no selection): snapshot the search space's current prefs.
        # Either way the captured ids are guaranteed billable (premium/BYOK).
        selected_models = payload.definition.models
        if selected_models is not None:
            self._assert_selected_models_billable(selected_models)
        else:
            search_space = await self._assert_models_billable(payload.search_space_id)
            payload.definition.models = AutomationModels(
                agent_llm_id=search_space.agent_llm_id or 0,
                image_generation_config_id=search_space.image_generation_config_id or 0,
                vision_llm_config_id=search_space.vision_llm_config_id or 0,
            )

        automation = Automation(
            search_space_id=payload.search_space_id,
            created_by_user_id=self.user.id,
            name=payload.name,
            description=payload.description,
            definition=payload.definition.model_dump(mode="json", by_alias=True),
            version=1,
        )
        for spec in payload.triggers:
            automation.triggers.append(_build_trigger(spec))

        self.session.add(automation)
        await self.session.commit()
        # Re-read so the returned row has its triggers eagerly loaded.
        return await self._get_with_triggers_or_raise(automation.id)

    async def list(
        self,
        *,
        search_space_id: int,
        limit: int,
        offset: int,
    ) -> tuple[list[Automation], int]:
        """Return a page of automations (newest first) and the total count."""
        await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value)

        base = select(Automation).where(Automation.search_space_id == search_space_id)
        total = await self.session.scalar(
            select(func.count()).select_from(base.subquery())
        )

        rows = (
            (
                await self.session.execute(
                    base.order_by(Automation.created_at.desc())
                    .limit(limit)
                    .offset(offset)
                )
            )
            .scalars()
            .all()
        )
        return list(rows), int(total or 0)

    async def get(self, automation_id: int) -> Automation:
        """Get an automation with its triggers loaded."""
        automation = await self._get_with_triggers_or_raise(automation_id)
        await self._authorize(
            automation.search_space_id, Permission.AUTOMATIONS_READ.value
        )
        return automation

    async def update(self, automation_id: int, patch: AutomationUpdate) -> Automation:
        """Patch fields. Bumps ``version`` when ``definition`` changes.

        Raises ``HTTPException`` 404 for an unknown automation and 422 when the
        patch carries an explicit null ``definition`` or a changed,
        non-billable model selection.
        """
        automation = await self._get_with_triggers_or_raise(automation_id)
        await self._authorize(
            automation.search_space_id, Permission.AUTOMATIONS_UPDATE.value
        )

        data = patch.model_dump(exclude_unset=True)

        # ``exclude_unset`` keeps a key the client sent explicitly, even when
        # its value is null. Reject a null definition up front (before touching
        # the ORM row) instead of failing later on ``None.model_dump``.
        # NOTE(review): assumes ``AutomationUpdate.definition`` is nullable --
        # confirm against the schema.
        if "definition" in data and patch.definition is None:
            raise HTTPException(status_code=422, detail="definition must not be null")

        if "name" in data:
            automation.name = data["name"]
        if "description" in data:
            automation.description = data["description"]
        if "status" in data:
            automation.status = data["status"]
        if "definition" in data:
            new_def = patch.definition.model_dump(mode="json", by_alias=True)
            # Model snapshot handling on edit:
            #   * absent in the patch -> preserve the captured snapshot
            #     (a non-model definition change never silently re-binds the
            #     automation to the current chat/search-space selection).
            #   * unchanged from the snapshot -> keep as-is, no re-validation
            #     (so editing an automation whose captured model later drifted
            #     out of premium isn't blocked by an unrelated name/schedule edit).
            #   * genuinely changed -> validate the new selection (422 on a
            #     non-billable pick), then accept it.
            existing_models = (automation.definition or {}).get("models")
            provided_models = new_def.get("models")
            if provided_models is None:
                if existing_models is not None:
                    new_def["models"] = existing_models
            elif provided_models != existing_models:
                self._assert_selected_models_billable(patch.definition.models)
            automation.definition = new_def
            automation.version += 1

        await self.session.commit()
        return await self._get_with_triggers_or_raise(automation_id)

    async def delete(self, automation_id: int) -> None:
        """Delete an automation; FK cascades remove triggers and runs."""
        automation = await self._get_or_raise(automation_id)
        await self._authorize(
            automation.search_space_id, Permission.AUTOMATIONS_DELETE.value
        )
        await self.session.delete(automation)
        await self.session.commit()

    async def _get_or_raise(self, automation_id: int) -> Automation:
        """Load an automation by primary key or raise a 404."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        return automation

    async def _get_with_triggers_or_raise(self, automation_id: int) -> Automation:
        """Load an automation with ``triggers`` eagerly loaded, or raise a 404."""
        stmt = (
            select(Automation)
            .where(Automation.id == automation_id)
            .options(selectinload(Automation.triggers))
        )
        automation = (await self.session.execute(stmt)).scalar_one_or_none()
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        return automation

    async def model_eligibility(self, *, search_space_id: int) -> dict:
        """Return whether a search space's models are billable for automations.

        ``{"allowed": bool, "violations": [{kind, config_id, reason}, ...]}``.
        """
        await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value)
        search_space = await self.session.get(SearchSpace, search_space_id)
        if search_space is None:
            raise HTTPException(
                status_code=404, detail=f"search space {search_space_id} not found"
            )
        return get_automation_model_eligibility(search_space)

    async def _assert_models_billable(self, search_space_id: int) -> SearchSpace:
        """Reject creation when the search space's models aren't billable.

        Automations may only use premium global models or user BYOK models; free
        global models and Auto mode are blocked. Mirrors the runtime backstop in
        ``agent_task`` so users can't save an automation that would fail to run.

        Returns the loaded :class:`SearchSpace` so the caller can capture its
        model prefs without a second DB read.
        """
        search_space = await self.session.get(SearchSpace, search_space_id)
        if search_space is None:
            raise HTTPException(
                status_code=404, detail=f"search space {search_space_id} not found"
            )
        try:
            assert_automation_models_billable(search_space)
        except AutomationModelPolicyError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        return search_space

    def _assert_selected_models_billable(self, models: AutomationModels) -> None:
        """Reject an explicitly selected model profile that isn't billable.

        Used when the client supplies ``definition.models`` (per-automation
        selection from the builder or chat approval card). Same policy as the
        search-space path: premium global or BYOK only, no free/Auto.
        """
        try:
            assert_models_billable(
                agent_llm_id=models.agent_llm_id,
                image_generation_config_id=models.image_generation_config_id,
                vision_llm_config_id=models.vision_llm_config_id,
            )
        except AutomationModelPolicyError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

    async def _authorize(self, search_space_id: int, permission: str) -> None:
        """Delegate to ``check_permission`` for this user and search space.

        The denial message uses the part of ``permission`` after the first
        ``:`` as the verb.
        """
        await check_permission(
            self.session,
            self.user,
            search_space_id,
            permission,
            f"You don't have permission to {permission.split(':')[1]} automations in this search space",
        )
def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
    """Tell whether a global config id resolves to a ``premium`` billing tier.

    The LLM kind is resolved through ``load_global_llm_config_by_id``; image
    and vision kinds are looked up by ``id`` in the matching global config
    list. An id that resolves to nothing (or to an empty config) is not
    premium, and a config without ``billing_tier`` is treated as ``free``.
    """
    from app.config import config as app_config

    if kind == "llm":
        from app.agents.new_chat.llm_config import load_global_llm_config_by_id

        entry = load_global_llm_config_by_id(config_id)
    else:
        # Anything that is not "llm" or "image" is handled as vision.
        catalogue = (
            app_config.GLOBAL_IMAGE_GEN_CONFIGS
            if kind == "image"
            else app_config.GLOBAL_VISION_LLM_CONFIGS
        )
        entry = None
        for candidate in catalogue:
            if candidate.get("id") == config_id:
                entry = candidate
                break

    if entry:
        tier = str(entry.get("billing_tier", "free"))
        return tier.lower() == "premium"
    return False
def get_model_eligibility(
    *,
    agent_llm_id: int | None,
    image_generation_config_id: int | None,
    vision_llm_config_id: int | None,
) -> dict:
    """Classify the three model slots and report every blocked one.

    Returns ``{"allowed": bool, "violations": [...]}`` where each violation is
    ``{"kind", "config_id", "reason"}`` and ``allowed`` is true only when the
    list is empty. Slots are checked in the order llm, image, vision.
    """
    slots = {
        "llm": agent_llm_id,
        "image": image_generation_config_id,
        "vision": vision_llm_config_id,
    }

    # One ``_classify`` call per slot, in insertion order.
    verdicts = [(kind, slot_id, *_classify(kind, slot_id)) for kind, slot_id in slots.items()]

    violations = [
        {"kind": kind, "config_id": slot_id, "reason": reason}
        for kind, slot_id, ok, reason in verdicts
        if not ok
    ]
    return {"allowed": len(violations) == 0, "violations": violations}
class AutomationModelPolicyError(Exception):
    """Signals that one or more model slots are not billable for automations.

    ``violations`` keeps the list passed in; the exception message is the
    violations' ``reason`` strings joined with ``"; "``, or a generic sentence
    when that join is empty.
    """

    def __init__(self, violations: list[dict]) -> None:
        self.violations = violations
        message = "; ".join([entry["reason"] for entry in violations])
        if not message:
            message = "Automations require premium or BYOK models for all model slots."
        super().__init__(message)
class RunService:
    """Read-only view over the ``AutomationRun`` rows of an automation."""

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        self.session = session
        self.user = user

    async def list(
        self,
        *,
        automation_id: int,
        limit: int,
        offset: int,
    ) -> tuple[list[AutomationRun], int]:
        """Page through an automation's runs, newest first, with the total count."""
        await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)

        runs_query = select(AutomationRun).where(
            AutomationRun.automation_id == automation_id
        )

        # Count over the unpaged query first, then fetch the requested page.
        count_query = select(func.count()).select_from(runs_query.subquery())
        total = await self.session.scalar(count_query)

        page_query = (
            runs_query.order_by(AutomationRun.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(page_query)
        page = [*result.scalars().all()]
        return page, int(total or 0)

    async def get(self, *, automation_id: int, run_id: int) -> AutomationRun:
        """Fetch one run; a run belonging to another automation is a 404."""
        await self._authorize(automation_id, Permission.AUTOMATIONS_READ.value)
        run = await self.session.get(AutomationRun, run_id)
        if run is not None and run.automation_id == automation_id:
            return run
        raise HTTPException(status_code=404, detail=f"run {run_id} not found")

    async def _authorize(self, automation_id: int, permission: str) -> Automation:
        """Load the automation (404 if missing) and check ``permission`` on its search space."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        verb = permission.split(":")[1]
        denial = f"You don't have permission to {verb} automations in this search space"
        await check_permission(
            self.session,
            self.user,
            automation.search_space_id,
            permission,
            denial,
        )
        return automation
class TriggerService:
    """Lifecycle of the ``AutomationTrigger`` sub-resource.

    Every operation requires the ``AUTOMATIONS_UPDATE`` permission on the
    parent automation's search space and commits its own transaction.
    """

    def __init__(self, *, session: AsyncSession, user: User) -> None:
        """Bind the service to a DB session and the acting user."""
        self.session = session
        self.user = user

    async def add(
        self, *, automation_id: int, payload: TriggerCreate
    ) -> AutomationTrigger:
        """Validate ``payload`` and attach a new trigger to the automation.

        Raises ``HTTPException`` 404 for an unknown automation and 422 for an
        unknown trigger type or invalid params.
        """
        automation = await self._authorize_automation(
            automation_id, Permission.AUTOMATIONS_UPDATE.value
        )

        validated_params = _validate_params(payload.type, payload.params)
        trigger = AutomationTrigger(
            automation_id=automation.id,
            type=payload.type,
            params=validated_params,
            static_inputs=payload.static_inputs,
            enabled=payload.enabled,
            next_fire_at=_initial_next_fire(
                payload.type, validated_params, payload.enabled
            ),
        )
        self.session.add(trigger)
        await self.session.commit()
        await self.session.refresh(trigger)
        return trigger

    async def update(
        self,
        *,
        automation_id: int,
        trigger_id: int,
        patch: TriggerUpdate,
    ) -> AutomationTrigger:
        """Apply the fields present in ``patch`` to an existing trigger.

        ``next_fire_at`` of a schedule trigger is recomputed only when its
        params or enabled flag actually changed (or it is enabled but has no
        ``next_fire_at``); other edits leave the pending fire time untouched.
        """
        await self._authorize_automation(
            automation_id, Permission.AUTOMATIONS_UPDATE.value
        )
        trigger = await self._get_trigger_or_raise(automation_id, trigger_id)

        data = patch.model_dump(exclude_unset=True)
        previous_params = trigger.params
        previous_enabled = trigger.enabled

        if "params" in data:
            trigger.params = _validate_params(trigger.type, data["params"])

        if "static_inputs" in data:
            trigger.static_inputs = data["static_inputs"]

        if "enabled" in data:
            trigger.enabled = data["enabled"]

        # Recompute next_fire_at when schedule timing changed or the trigger was
        # toggled. Recomputing unconditionally would move an already-due
        # next_fire_at to the following occurrence on an unrelated edit such as
        # static_inputs -- presumably the scheduler would then skip that fire
        # (TODO confirm against the scheduler).
        if trigger.type == TriggerType.SCHEDULE:
            timing_changed = trigger.params != previous_params
            enabled_changed = trigger.enabled != previous_enabled
            missing_fire_time = bool(trigger.enabled) and trigger.next_fire_at is None
            if timing_changed or enabled_changed or missing_fire_time:
                trigger.next_fire_at = _initial_next_fire(
                    trigger.type, trigger.params, trigger.enabled
                )

        await self.session.commit()
        await self.session.refresh(trigger)
        return trigger

    async def remove(self, *, automation_id: int, trigger_id: int) -> None:
        """Delete a trigger that belongs to the given automation."""
        await self._authorize_automation(
            automation_id, Permission.AUTOMATIONS_UPDATE.value
        )
        trigger = await self._get_trigger_or_raise(automation_id, trigger_id)
        await self.session.delete(trigger)
        await self.session.commit()

    async def _authorize_automation(
        self, automation_id: int, permission: str
    ) -> Automation:
        """Load the automation (404 if missing) and check ``permission`` on its search space."""
        automation = await self.session.get(Automation, automation_id)
        if automation is None:
            raise HTTPException(
                status_code=404, detail=f"automation {automation_id} not found"
            )
        await check_permission(
            self.session,
            self.user,
            automation.search_space_id,
            permission,
            f"You don't have permission to {permission.split(':')[1]} automations in this search space",
        )
        return automation

    async def _get_trigger_or_raise(
        self, automation_id: int, trigger_id: int
    ) -> AutomationTrigger:
        """Load a trigger; one that belongs to a different automation is a 404."""
        trigger = await self.session.get(AutomationTrigger, trigger_id)
        if trigger is None or trigger.automation_id != automation_id:
            raise HTTPException(
                status_code=404, detail=f"trigger {trigger_id} not found"
            )
        return trigger
async def _impl(run_id: int) -> None:
    """Execute the run identified by ``run_id`` inside its own session.

    Opens a session from the Celery session maker and awaits ``execute_run``.
    Any exception is logged with its traceback, the session is rolled back,
    and the exception is re-raised to the caller.
    """
    session_maker = get_celery_session_maker()
    async with session_maker() as session:
        try:
            await execute_run(session, run_id)
        except Exception:
            # Log first so the traceback is captured even if rollback fails.
            logger.exception("automation_run %d failed unexpectedly", run_id)
            await session.rollback()
            # Re-raise so the failure is not swallowed here.
            raise
def build_run_context(
    *,
    run_id: int,
    automation_id: int,
    automation_name: str | None,
    automation_version: int | None,
    search_space_id: int | None,
    creator_id: Any,
    trigger_id: int | None,
    trigger_type: str | None,
    started_at: datetime | None,
    attempt: int,
    inputs: Mapping[str, Any],
    step_outputs: Mapping[str, Any],
) -> dict[str, Any]:
    """Assemble the template namespace with ``run``, ``inputs`` and ``steps`` keys.

    ``run`` carries the scalar run/automation/trigger facts passed in;
    ``inputs`` and ``steps`` are shallow ``dict`` copies of the given mappings.
    """
    run_facts: dict[str, Any] = dict(
        id=run_id,
        automation_id=automation_id,
        automation_name=automation_name,
        automation_version=automation_version,
        search_space_id=search_space_id,
        creator_id=creator_id,
        trigger_id=trigger_id,
        trigger_type=trigger_type,
        started_at=started_at,
        attempt=attempt,
    )

    namespace: dict[str, Any] = {"run": run_facts}
    namespace["inputs"] = dict(inputs)
    namespace["steps"] = dict(step_outputs)
    return namespace
allowlist.""" + +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any + +from jinja2 import StrictUndefined +from jinja2.sandbox import SandboxedEnvironment + +from .allowlist import ALLOWED_FILTERS, ALLOWED_TESTS +from .filters import filter_date, filter_slugify + + +def _finalize(value: Any) -> Any: + """Stringify common non-string values at output sites.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, list | dict): + return json.dumps(value, ensure_ascii=False, default=str) + return value + + +def _build_env() -> SandboxedEnvironment: + env = SandboxedEnvironment( + autoescape=False, + undefined=StrictUndefined, + finalize=_finalize, + ) + env.globals.clear() + env.filters = {k: v for k, v in env.filters.items() if k in ALLOWED_FILTERS} + env.filters["date"] = filter_date + env.filters["slugify"] = filter_slugify + env.tests = {k: v for k, v in env.tests.items() if k in ALLOWED_TESTS} + return env + + +ENV: SandboxedEnvironment = _build_env() diff --git a/surfsense_backend/app/automations/templating/filters.py b/surfsense_backend/app/automations/templating/filters.py new file mode 100644 index 000000000..65f66eb37 --- /dev/null +++ b/surfsense_backend/app/automations/templating/filters.py @@ -0,0 +1,29 @@ +"""Custom Jinja filters registered into the sandboxed environment.""" + +from __future__ import annotations + +import re +from typing import Any + + +def filter_date(value: Any, fmt: str = "%Y-%m-%d") -> str: + """Format a datetime-like value with ``strftime``. 
def filter_date(value: Any, fmt: str = "%Y-%m-%d") -> str:
    """Render a date/datetime-like ``value`` via ``strftime(fmt)``.

    ``None`` becomes an empty string and strings are returned untouched.
    Any other value lacking a ``strftime`` attribute raises ``ValueError``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if not hasattr(value, "strftime"):
        raise ValueError(
            f"date filter requires datetime-like, got {type(value).__name__}"
        )
    return value.strftime(fmt)


_SLUG_NONALNUM = re.compile(r"[^a-z0-9]+")
_SLUG_DASHES = re.compile(r"-+")


def filter_slugify(value: Any) -> str:
    """Turn ``value`` into a lowercase, hyphen-separated ASCII slug.

    The value is stringified and lowercased; every run of characters outside
    ``a-z0-9`` becomes a single hyphen, and leading/trailing hyphens are
    stripped.
    """
    lowered = str(value).lower()
    hyphenated = _SLUG_DASHES.sub("-", _SLUG_NONALNUM.sub("-", lowered))
    return hyphenated.strip("-")
000000000..9d28ddf5f --- /dev/null +++ b/surfsense_backend/app/automations/triggers/__init__.py @@ -0,0 +1,19 @@ +"""Triggers domain: registry surface + built-in trigger packages. + +Built-in trigger types live under ``builtin/`` and self-register at import time. +""" + +from __future__ import annotations + +from .store import all_triggers, get_trigger, register_trigger +from .types import TriggerDefinition + +__all__ = [ + "TriggerDefinition", + "all_triggers", + "get_trigger", + "register_trigger", +] + +# Built-in triggers self-register at import time. +from . import builtin # noqa: F401 diff --git a/surfsense_backend/app/automations/triggers/builtin/__init__.py b/surfsense_backend/app/automations/triggers/builtin/__init__.py new file mode 100644 index 000000000..17d8e914b --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/__init__.py @@ -0,0 +1,5 @@ +"""Built-in trigger types — each in its own subpackage, self-registering at import.""" + +from __future__ import annotations + +from . import event, schedule # noqa: F401 diff --git a/surfsense_backend/app/automations/triggers/builtin/event/__init__.py b/surfsense_backend/app/automations/triggers/builtin/event/__init__.py new file mode 100644 index 000000000..8dc89dfa1 --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/event/__init__.py @@ -0,0 +1,29 @@ +"""``event`` trigger: fire an automation when a matching domain event is published. + +Subscribes to the event bus and matches events against a user-authored JSON +filter (see :mod:`.filter`). 
+""" + +from __future__ import annotations + +from app.event_bus import bus + +from .filter import FilterError, matches +from .inputs import event_runtime_inputs +from .match import trigger_matches_event +from .params import EventTriggerParams +from .source import on_event + +__all__ = [ + "EventTriggerParams", + "FilterError", + "event_runtime_inputs", + "matches", + "trigger_matches_event", +] + +# Side-effect: register on the triggers store. +from . import definition # noqa: F401 + +# Side-effect: react to published events. +bus.subscribe(on_event) diff --git a/surfsense_backend/app/automations/triggers/builtin/event/definition.py b/surfsense_backend/app/automations/triggers/builtin/event/definition.py new file mode 100644 index 000000000..b1ef6d4e2 --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/event/definition.py @@ -0,0 +1,16 @@ +"""``event`` ``TriggerDefinition`` registration.""" + +from __future__ import annotations + +from app.automations.triggers.store import register_trigger +from app.automations.triggers.types import TriggerDefinition + +from .params import EventTriggerParams + +EVENT_TRIGGER = TriggerDefinition( + type="event", + description="Fire when a matching domain event is published.", + params_model=EventTriggerParams, +) + +register_trigger(EVENT_TRIGGER) diff --git a/surfsense_backend/app/automations/triggers/builtin/event/filter.py b/surfsense_backend/app/automations/triggers/builtin/event/filter.py new file mode 100644 index 000000000..742281fc6 --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/event/filter.py @@ -0,0 +1,77 @@ +"""Pure JSON filter grammar: ``matches(filter_expr, payload) -> bool``. + +The ``event`` trigger uses it to decide whether an event fires the automation. +""" + +from __future__ import annotations + +import operator +from collections.abc import Callable +from typing import Any + + +class FilterError(ValueError): + """Unknown operator in a filter. 
class FilterError(ValueError):
    """Malformed filter: unknown operator or wrongly-shaped operand.

    Raised (not silently false) so a bad filter fails loudly instead of
    quietly disabling the trigger.
    """


# Scalar comparison operators: (actual, operand) -> bool.
_COMPARATORS: dict[str, Callable[[Any, Any], bool]] = {
    "$eq": operator.eq,
    "$ne": operator.ne,
    "$gt": operator.gt,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$lte": operator.le,
    "$in": lambda actual, operand: actual in operand,
    "$nin": lambda actual, operand: actual not in operand,
}

# Operators whose operand must support containment checks.
_MEMBERSHIP_OPS = frozenset({"$in", "$nin"})
_CONTAINER_TYPES = (list, tuple, set, frozenset, dict, str)

# Sentinel for "the payload has no such field" — distinct from a present None.
_MISSING = object()


def matches(filter_expr: dict[str, Any], payload: dict[str, Any]) -> bool:
    """Return ``True`` when ``payload`` satisfies every constraint in ``filter_expr``.

    An empty filter expresses "no constraints" and matches every payload.
    Sibling keys (fields and logical operators alike) are ANDed together.

    Raises:
        FilterError: if the filter (or a nested sub-filter) is not a dict,
            uses an unknown ``$``-operator, or gives ``$and``/``$or`` a
            non-list operand.
    """
    if not isinstance(filter_expr, dict):
        raise FilterError(
            f"filter must be an object, got {type(filter_expr).__name__}"
        )
    for key, value in filter_expr.items():
        if key == "$and":
            if not all(matches(sub, payload) for sub in _clauses(key, value)):
                return False
        elif key == "$or":
            if not any(matches(sub, payload) for sub in _clauses(key, value)):
                return False
        elif key == "$not":
            if matches(value, payload):
                return False
        elif key.startswith("$"):
            raise FilterError(f"unknown logical operator: {key}")
        elif not _match_condition(value, payload.get(key, _MISSING)):
            return False
    return True


def _clauses(op: str, value: Any) -> list[Any] | tuple[Any, ...]:
    """Validate that a logical operator's operand is a list of sub-filters."""
    if not isinstance(value, (list, tuple)):
        raise FilterError(
            f"{op} requires a list of filters, got {type(value).__name__}"
        )
    return value


def _match_condition(condition: Any, actual: Any) -> bool:
    """Match one field's ``actual`` value against its ``condition``.

    A dict condition is an operator object (``{"$gt": 10}``); every operator in
    it must hold. Any other value is an implicit equality check. A field absent
    from the payload (``actual is _MISSING``) fails every constraint.
    """
    if actual is _MISSING:
        return False
    if isinstance(condition, dict):
        return all(
            _apply_operator(op, operand, actual) for op, operand in condition.items()
        )
    return actual == condition


def _apply_operator(op: str, operand: Any, actual: Any) -> bool:
    """Apply one comparison operator to a payload value.

    A malformed filter (unknown operator, non-container ``$in``/``$nin``
    operand) raises ``FilterError``. A payload value that simply cannot be
    compared with the operand (e.g. ``None`` against a number) is a
    non-match, not a crash: event payloads are not under the filter
    author's control, and an exception here would abort matching for the
    whole event.
    """
    comparator = _COMPARATORS.get(op)
    if comparator is None:
        raise FilterError(f"unknown operator: {op}")
    if op in _MEMBERSHIP_OPS and not isinstance(operand, _CONTAINER_TYPES):
        raise FilterError(
            f"{op} requires a list operand, got {type(operand).__name__}"
        )
    try:
        return bool(comparator(actual, operand))
    except TypeError:
        # Incomparable types: the constraint cannot hold.
        return False
# Params accepted by the ``event`` trigger type: which event to listen for and
# an optional JSON filter on its payload. Documented with comments rather than
# a class docstring so the JSON schema derived from this model stays as-is.
class EventTriggerParams(BaseModel):
    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Required; must be a non-empty string (``min_length=1``).
    event_type: str = Field(
        ...,
        min_length=1,
        description="Event type to listen for.",
        examples=["document.indexed"],
    )
    # Optional; defaults to a fresh empty dict when omitted.
    filter: dict[str, Any] = Field(
        default_factory=dict,
        description="JSON filter matched against the event payload.",
        examples=[{"document_type": "FILE"}],
    )
@celery_app.task(name=TASK_NAME)
def automation_event_select(event: dict[str, Any]) -> None:
    """Select and start the runs an event fires.

    ``event`` is the JSON-serialized ``Event`` enqueued by the bus subscriber.
    """
    return run_async_celery_task(lambda: _select_and_start(event))


async def _select_and_start(event_dict: dict[str, Any]) -> None:
    """Validate the event, pick the matching triggers and start one run each."""
    event = Event.model_validate(event_dict)
    session_maker = get_celery_session_maker()
    async with session_maker() as session:
        eligible = await _eligible(session, event=event)
        # Snapshot the ids first. A failed start rolls the session back, which
        # expires every loaded row; reloading by id per iteration (same
        # pattern as the schedule selector) keeps one failure from breaking
        # the triggers that follow it.
        trigger_ids = [trigger.id for trigger in eligible]
        for trigger_id in trigger_ids:
            trigger = await session.get(AutomationTrigger, trigger_id)
            if trigger is None:
                continue
            await _start_one(session, trigger=trigger, event=event)


async def _eligible(session: AsyncSession, *, event: Event) -> list[AutomationTrigger]:
    """Enabled ``event`` triggers for this event type whose filter matches.

    A trigger whose stored filter cannot be evaluated (malformed filter,
    unexpected params shape) is logged and skipped so it cannot prevent the
    other triggers from firing.
    """
    stmt = select(AutomationTrigger).where(
        AutomationTrigger.type == TriggerType.EVENT,
        AutomationTrigger.enabled.is_(True),
        AutomationTrigger.params["event_type"].astext == event.event_type,
    )
    triggers = (await session.execute(stmt)).scalars().all()

    eligible: list[AutomationTrigger] = []
    for trigger in triggers:
        try:
            if trigger_matches_event(trigger.params, event):
                eligible.append(trigger)
        except Exception:
            logger.exception(
                "event filter evaluation failed for trigger %d; skipping",
                trigger.id,
            )
    return eligible


async def _start_one(
    session: AsyncSession, *, trigger: AutomationTrigger, event: Event
) -> None:
    """Start one run for ``trigger``; log and roll back on failure.

    Failures are swallowed on purpose so the remaining triggers still fire.
    """
    # Read the id up front so the failure log never has to touch ORM state
    # that a failed flush/commit inside ``launch_run`` may have invalidated.
    trigger_id = trigger.id
    try:
        run = await launch_run(
            session=session,
            trigger=trigger,
            runtime_inputs=event_runtime_inputs(event),
        )
        logger.info(
            "event fire: trigger=%d automation=%d run=%d event=%s",
            trigger_id,
            trigger.automation_id,
            run.id,
            event.event_id,
        )
    except Exception:
        logger.exception("event fire failed for trigger %d", trigger_id)
        await session.rollback()
import definition # noqa: F401 diff --git a/surfsense_backend/app/automations/triggers/builtin/schedule/cron.py b/surfsense_backend/app/automations/triggers/builtin/schedule/cron.py new file mode 100644 index 000000000..a8401e4a3 --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/schedule/cron.py @@ -0,0 +1,41 @@ +"""Cron math for the ``schedule`` trigger: validate + advance ``next_fire_at``.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + +from croniter import CroniterBadCronError, croniter + + +class InvalidCronError(ValueError): + """Raised when a cron expression or timezone fails validation.""" + + +def validate_cron(cron: str, timezone: str) -> None: + """Raise ``InvalidCronError`` if cron or timezone are unusable.""" + try: + ZoneInfo(timezone) + except ZoneInfoNotFoundError as exc: + raise InvalidCronError(f"unknown timezone {timezone!r}") from exc + + try: + croniter(cron) + except (CroniterBadCronError, ValueError) as exc: + raise InvalidCronError(f"invalid cron {cron!r}: {exc}") from exc + + +def compute_next_fire_at(cron: str, timezone: str, *, after: datetime) -> datetime: + """Return the next moment matching ``cron`` in ``timezone`` strictly after ``after``. + + The result is normalized to UTC for storage. ``after`` is converted into the + given timezone before evaluation so DST and IANA rules apply correctly. 
+ """ + tz = ZoneInfo(timezone) + base = ( + after.astimezone(tz) + if after.tzinfo + else after.replace(tzinfo=UTC).astimezone(tz) + ) + nxt: datetime = croniter(cron, base).get_next(datetime) + return nxt.astimezone(UTC) diff --git a/surfsense_backend/app/automations/triggers/builtin/schedule/definition.py b/surfsense_backend/app/automations/triggers/builtin/schedule/definition.py new file mode 100644 index 000000000..a6b0b9b8e --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/schedule/definition.py @@ -0,0 +1,16 @@ +"""``schedule`` ``TriggerDefinition`` registration.""" + +from __future__ import annotations + +from app.automations.triggers.store import register_trigger +from app.automations.triggers.types import TriggerDefinition + +from .params import ScheduleTriggerParams + +SCHEDULE_TRIGGER = TriggerDefinition( + type="schedule", + description="Fire on a cron schedule in a given timezone.", + params_model=ScheduleTriggerParams, +) + +register_trigger(SCHEDULE_TRIGGER) diff --git a/surfsense_backend/app/automations/triggers/builtin/schedule/inputs.py b/surfsense_backend/app/automations/triggers/builtin/schedule/inputs.py new file mode 100644 index 000000000..947975b28 --- /dev/null +++ b/surfsense_backend/app/automations/triggers/builtin/schedule/inputs.py @@ -0,0 +1,27 @@ +"""Build run inputs from a schedule fire.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + + +def schedule_runtime_inputs( + *, + fired_at: datetime, + scheduled_for: datetime, + previous_last_fired_at: datetime | None, +) -> dict[str, Any]: + """Calendar context for a scheduled run. 
def schedule_runtime_inputs(
    *,
    fired_at: datetime,
    scheduled_for: datetime,
    previous_last_fired_at: datetime | None,
) -> dict[str, Any]:
    """Build the run inputs describing a scheduled fire.

    All timestamps are emitted as ISO-8601 strings:

    - ``fired_at`` — when the fire actually happened
    - ``scheduled_for`` — the cron-derived time this fire was due
    - ``last_fired_at`` — the previous fire time, or ``None`` on first fire
    """
    runtime_inputs: dict[str, Any] = {
        "fired_at": fired_at.isoformat(),
        "scheduled_for": scheduled_for.isoformat(),
    }
    if previous_last_fired_at:
        runtime_inputs["last_fired_at"] = previous_last_fired_at.isoformat()
    else:
        runtime_inputs["last_fired_at"] = None
    return runtime_inputs
**Self-heal**: enabled schedule triggers with NULL ``next_fire_at`` get it + computed from their ``cron`` + ``timezone`` (fresh inserts, restored rows). +2. **Claim & start**: due rows are locked ``FOR UPDATE SKIP LOCKED``, their + ``next_fire_at`` is advanced and ``last_fired_at`` set, and a run is started + for each. A missed fire stays missed (``catchup=False`` semantics). +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.dispatch import launch_run +from app.automations.persistence.enums.trigger_type import TriggerType +from app.automations.persistence.models.trigger import AutomationTrigger +from app.celery_app import celery_app +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task + +from .cron import InvalidCronError, compute_next_fire_at +from .inputs import schedule_runtime_inputs +from .source import TASK_NAME + +logger = logging.getLogger(__name__) + +# Cap rows touched per tick so a backlog of due triggers can't starve the +# worker; remaining rows fire on the next tick. 
# Cap rows touched per tick so a backlog of due triggers can't starve the
# worker; remaining rows fire on the next tick.
_TICK_BATCH = 200


@dataclass(frozen=True, slots=True)
class _Claim:
    """Per-trigger fire context captured before row state is mutated."""

    trigger_id: int
    scheduled_for: datetime
    previous_last_fired_at: datetime | None


@celery_app.task(name=TASK_NAME)
def automation_schedule_select() -> None:
    """Tick once: self-heal NULL next_fire_at, claim due rows, start each."""
    return run_async_celery_task(_tick)


async def _tick() -> None:
    """Run both selector passes against a single session and timestamp."""
    session_maker = get_celery_session_maker()
    async with session_maker() as session:
        now = datetime.now(UTC)

        await _self_heal_null_next_fire(session, now=now)

        for claim in await _claim_due_triggers(session, now=now):
            await _start_one(session, claim=claim, fired_at=now)


def _advance_or_disable(trigger: AutomationTrigger, *, now: datetime) -> bool:
    """Set ``trigger.next_fire_at`` to the next match after ``now``.

    Returns ``True`` on success. If the stored schedule params are unusable
    the trigger is disabled instead and ``False`` is returned.

    ``ValueError`` is caught deliberately broadly: it covers
    ``InvalidCronError`` (a ``ValueError`` subclass) as well as a
    ``ValueError`` raised directly by cron parsing or timezone lookup.
    Letting one of those escape would abort the tick before its commit, so
    the bad row would never be disabled and would fail every later tick.
    ``KeyError``/``TypeError`` cover missing or wrongly-typed params.
    """
    try:
        trigger.next_fire_at = compute_next_fire_at(
            trigger.params["cron"],
            trigger.params["timezone"],
            after=now,
        )
    except (ValueError, KeyError, TypeError) as exc:
        logger.warning(
            "automation_trigger %d has invalid schedule params, disabling: %s",
            trigger.id,
            exc,
        )
        trigger.enabled = False
        return False
    return True


async def _self_heal_null_next_fire(session: AsyncSession, *, now: datetime) -> None:
    """Backfill ``next_fire_at`` for enabled schedule triggers missing it."""
    stmt = (
        select(AutomationTrigger)
        .where(
            AutomationTrigger.type == TriggerType.SCHEDULE,
            AutomationTrigger.enabled.is_(True),
            AutomationTrigger.next_fire_at.is_(None),
        )
        .limit(_TICK_BATCH)
    )
    triggers = (await session.execute(stmt)).scalars().all()
    if not triggers:
        return

    for trigger in triggers:
        _advance_or_disable(trigger, now=now)

    await session.commit()


async def _claim_due_triggers(session: AsyncSession, *, now: datetime) -> list[_Claim]:
    """Lock and advance due rows; return per-trigger fire context.

    Rows are selected ``FOR UPDATE SKIP LOCKED`` so concurrent ticks do not
    claim the same trigger. The advance is committed before any run is
    started, so a fire that later fails to start is not retried.
    """
    stmt = (
        select(AutomationTrigger)
        .where(
            AutomationTrigger.type == TriggerType.SCHEDULE,
            AutomationTrigger.enabled.is_(True),
            AutomationTrigger.next_fire_at.isnot(None),
            AutomationTrigger.next_fire_at <= now,
        )
        .order_by(AutomationTrigger.next_fire_at)
        .limit(_TICK_BATCH)
        .with_for_update(skip_locked=True)
    )
    triggers = (await session.execute(stmt)).scalars().all()
    if not triggers:
        return []

    claims: list[_Claim] = []
    for trigger in triggers:
        # Snapshot fire-context BEFORE we advance the row.
        scheduled_for = trigger.next_fire_at
        previous_last_fired_at = trigger.last_fired_at

        if not _advance_or_disable(trigger, now=now):
            # Disabled rows are still persisted by the commit below.
            continue

        trigger.last_fired_at = now
        claims.append(
            _Claim(
                trigger_id=trigger.id,
                scheduled_for=scheduled_for,
                previous_last_fired_at=previous_last_fired_at,
            )
        )

    await session.commit()
    return claims


async def _start_one(
    session: AsyncSession, *, claim: _Claim, fired_at: datetime
) -> None:
    """Reload the trigger post-commit and start a run for it.

    A trigger deleted since it was claimed is skipped. Start failures are
    logged and rolled back so the remaining claims still fire.
    """
    trigger = await session.get(AutomationTrigger, claim.trigger_id)
    if trigger is None:
        return

    try:
        run = await launch_run(
            session=session,
            trigger=trigger,
            runtime_inputs=schedule_runtime_inputs(
                fired_at=fired_at,
                scheduled_for=claim.scheduled_for,
                previous_last_fired_at=claim.previous_last_fired_at,
            ),
        )
        logger.info(
            "scheduled fire: trigger=%d automation=%d run=%d",
            claim.trigger_id,
            trigger.automation_id,
            run.id,
        )
    except Exception:
        logger.exception(
            "scheduled fire failed for trigger %d (next attempt at next match)",
            claim.trigger_id,
        )
        await session.rollback()
_REGISTRY: dict[str, TriggerDefinition] = {}


def register_trigger(trigger: TriggerDefinition) -> None:
    """Add ``trigger`` to the registry under its ``type``.

    Raises ``ValueError`` if that type has already been registered.
    """
    kind = trigger.type
    if kind not in _REGISTRY:
        _REGISTRY[kind] = trigger
        return
    raise ValueError(f"Trigger already registered: {kind!r}")


def get_trigger(trigger_type: str) -> TriggerDefinition | None:
    """Look up a registered trigger by type; ``None`` when unknown."""
    try:
        return _REGISTRY[trigger_type]
    except KeyError:
        return None


def all_triggers() -> dict[str, TriggerDefinition]:
    """Return a copy of the registry; mutating it leaves the registry untouched."""
    return {**_REGISTRY}
@before_task_publish.connect
def _stamp_enqueue_time(headers=None, **_kwargs):
    """Stamp enqueue time so workers can measure queue wait.

    The stamp is wall-clock nanoseconds (``time.time_ns``). It is written in
    the publishing process and read in a worker process; a monotonic clock
    has an undefined reference point and cannot be compared across
    processes or hosts.
    """
    if headers is None:
        return
    # Best-effort: instrumentation must never block publishing a task.
    with contextlib.suppress(Exception):
        headers["surfsense.enqueued_at_ns"] = str(time.time_ns())


@task_prerun.connect
def _record_queue_latency(task=None, **_kwargs):
    """Record queue latency and stash task metadata for span enrichment.

    Derived values are stored on ``task.request`` (``surfsense_operation``,
    ``surfsense_queue``, ``surfsense_scheduled``,
    ``surfsense_queue_latency_ms``) for the post-run handler. Any failure is
    swallowed: metrics are best-effort and must not fail the task.
    """
    if task is None:
        return
    try:
        from app.observability import metrics as ot_metrics

        task_name = getattr(task, "name", None) or "unknown"
        operation = ot_metrics.parse_celery_task_label(task_name)
        request = getattr(task, "request", None)
        delivery_info = getattr(request, "delivery_info", None) or {}
        queue = delivery_info.get("routing_key") or "unknown"
        scheduled = bool(
            getattr(request, "eta", None) or getattr(request, "expires", None)
        )

        with contextlib.suppress(Exception):
            request.surfsense_operation = operation
            request.surfsense_queue = queue
            request.surfsense_scheduled = scheduled

        headers = getattr(request, "headers", None) or {}
        enqueued_ns = headers.get("surfsense.enqueued_at_ns")
        if enqueued_ns is None:
            return

        # Same clock as the publisher-side stamp. Clamp at zero so small
        # clock skew between hosts never yields a negative latency.
        elapsed_ns = max(0, time.time_ns() - int(enqueued_ns))
        elapsed_s = elapsed_ns / 1e9
        with contextlib.suppress(Exception):
            request.surfsense_queue_latency_ms = elapsed_s * 1000

        ot_metrics.record_celery_queue_latency(
            elapsed_s,
            task_name=task_name,
            queue=queue,
            scheduled=scheduled,
            operation=operation,
        )
    except Exception:
        # Deliberate best-effort: never let telemetry break task execution.
        pass


@task_postrun.connect
def _set_celery_span_attributes(task=None, **_kwargs):
    """Attach derived queue metadata to the active Celery run span.

    No-op when OpenTelemetry is not installed (``trace is None``) or the
    pre-run handler did not stash any values on the request.
    """
    if task is None or trace is None:
        return

    try:
        request = getattr(task, "request", None)
        if request is None:
            return

        span = trace.get_current_span()

        operation = getattr(request, "surfsense_operation", None)
        if operation:
            span.set_attribute("celery.task.operation", operation)

        latency_ms = getattr(request, "surfsense_queue_latency_ms", None)
        if latency_ms is not None:
            span.set_attribute("celery.queue.latency_ms", latency_ms)
    except Exception:
        # Deliberate best-effort: span enrichment is optional.
        pass
""" + from app.observability.bootstrap import init_otel + + init_otel(app=None, traces=True, metrics=True, logs=True) + from app.config import ( initialize_image_gen_router, initialize_llm_router, @@ -97,6 +188,9 @@ celery_app = Celery( "app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.stale_notification_cleanup_task", "app.tasks.celery_tasks.stripe_reconciliation_task", + "app.automations.tasks.execute_run", + "app.automations.triggers.builtin.schedule.selector", + "app.automations.triggers.builtin.event.selector", ], ) @@ -154,6 +248,12 @@ celery_app.conf.update( }, ) +# Imported late (after celery_app is built) to keep the automations triggers +# package out of this module's top-level import graph. +from app.automations.triggers.builtin.schedule.source import ( # noqa: E402 + BEAT_SCHEDULE as SCHEDULE_BEAT_SCHEDULE, +) + # Configure Celery Beat schedule # This uses a meta-scheduler pattern: instead of creating individual Beat schedules # for each connector, we have ONE schedule that checks the database at the configured interval @@ -191,4 +291,7 @@ celery_app.conf.beat_schedule = { "expires": 60, }, }, + # Fire due automation schedule triggers (Beat entry owned by the schedule + # trigger; see app.automations.triggers.builtin.schedule.source). 
+ **SCHEDULE_BEAT_SCHEDULE, } diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9fc27fb1f..d6ee9ff88 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -439,6 +439,13 @@ class Permission(StrEnum): PUBLIC_SHARING_CREATE = "public_sharing:create" PUBLIC_SHARING_DELETE = "public_sharing:delete" + # Automations + AUTOMATIONS_CREATE = "automations:create" + AUTOMATIONS_READ = "automations:read" + AUTOMATIONS_UPDATE = "automations:update" + AUTOMATIONS_DELETE = "automations:delete" + AUTOMATIONS_EXECUTE = "automations:execute" + # Full access wildcard FULL_ACCESS = "*" @@ -494,6 +501,11 @@ DEFAULT_ROLE_PERMISSIONS = { # Public Sharing (can create and view, no delete) Permission.PUBLIC_SHARING_VIEW.value, Permission.PUBLIC_SHARING_CREATE.value, + # Automations (no delete) + Permission.AUTOMATIONS_CREATE.value, + Permission.AUTOMATIONS_READ.value, + Permission.AUTOMATIONS_UPDATE.value, + Permission.AUTOMATIONS_EXECUTE.value, ], "Viewer": [ # Documents (read only) @@ -525,6 +537,8 @@ DEFAULT_ROLE_PERMISSIONS = { Permission.SETTINGS_VIEW.value, # Public Sharing (view only) Permission.PUBLIC_SHARING_VIEW.value, + # Automations (read only) + Permission.AUTOMATIONS_READ.value, ], } @@ -1136,46 +1150,6 @@ class Chunk(BaseModel, TimestampMixin): document = relationship("Document", back_populates="chunks") -class SurfsenseDocsDocument(BaseModel, TimestampMixin): - """ - Surfsense documentation storage. - Indexed at migration time from MDX files. 
- """ - - __tablename__ = "surfsense_docs_documents" - - source = Column( - String, nullable=False, unique=True, index=True - ) # File path: "connectors/slack.mdx" - title = Column(String, nullable=False) - content = Column(Text, nullable=False) - content_hash = Column(String, nullable=False, index=True) # For detecting changes - embedding = Column(Vector(config.embedding_model_instance.dimension)) - updated_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) - - chunks = relationship( - "SurfsenseDocsChunk", - back_populates="document", - cascade="all, delete-orphan", - ) - - -class SurfsenseDocsChunk(BaseModel, TimestampMixin): - """Chunk storage for Surfsense documentation.""" - - __tablename__ = "surfsense_docs_chunks" - - content = Column(Text, nullable=False) - embedding = Column(Vector(config.embedding_model_instance.dimension)) - - document_id = Column( - Integer, - ForeignKey("surfsense_docs_documents.id", ondelete="CASCADE"), - nullable=False, - ) - document = relationship("SurfsenseDocsDocument", back_populates="chunks") - - class Podcast(BaseModel, TimestampMixin): """Podcast model for storing generated podcasts.""" @@ -1533,6 +1507,14 @@ class SearchSpace(BaseModel, TimestampMixin): cascade="all, delete-orphan", ) + automations = relationship( + "Automation", + back_populates="search_space", + order_by="Automation.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) + # RBAC relationships roles = relationship( "SearchSpaceRole", @@ -2125,6 +2107,13 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) + # Automations created by this user + automations = relationship( + "Automation", + back_populates="created_by", + passive_deletes=True, + ) + # Incentive tasks completed by this user incentive_tasks = relationship( "UserIncentiveTask", @@ -2257,6 +2246,13 @@ else: passive_deletes=True, ) + # Automations created by this user + automations = relationship( + "Automation", + back_populates="created_by", + passive_deletes=True, 
+ ) + # Incentive tasks completed by this user incentive_tasks = relationship( "UserIncentiveTask", @@ -2560,6 +2556,15 @@ class RefreshToken(Base, TimestampMixin): return not self.is_expired and not self.is_revoked +# Register model packages that live outside this file so their classes +# are present in Base.metadata before configure_mappers() resolves any +# string-based relationship() references. +from app.automations.persistence import ( # noqa: E402, F401 + Automation, + AutomationRun, + AutomationTrigger, +) + engine = create_async_engine( DATABASE_URL, pool_size=30, @@ -2635,11 +2640,6 @@ async def setup_indexes(): "CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)" ) ) - await conn.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_surfsense_docs_title_trgm ON surfsense_docs_documents USING gin (title gin_trgm_ops)" - ) - ) async def create_db_and_tables(): diff --git a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py index 87e8138fd..a2f4d0bbd 100644 --- a/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py +++ b/surfsense_backend/app/etl_pipeline/etl_pipeline_service.py @@ -1,4 +1,7 @@ +import contextlib import logging +import time +from pathlib import PurePosixPath from app.config import config as app_config from app.etl_pipeline.etl_document import EtlRequest, EtlResult @@ -10,6 +13,11 @@ from app.etl_pipeline.file_classifier import FileCategory, classify_file from app.etl_pipeline.parsers.audio import transcribe_audio from app.etl_pipeline.parsers.direct_convert import convert_file_directly from app.etl_pipeline.parsers.plaintext import read_plaintext +from app.observability import metrics as ot_metrics, otel as ot + + +def _file_extension(filename: str) -> str: + return PurePosixPath(filename).suffix.lower() or "none" class EtlPipelineService: @@ -20,49 +28,93 @@ class 
EtlPipelineService: async def extract(self, request: EtlRequest) -> EtlResult: category = classify_file(request.filename) + start = time.perf_counter() + status = "success" + error_category: str | None = None + result: EtlResult | None = None + with ot.etl_extract_span( + content_type=category.value, + file_extension=_file_extension(request.filename), + processing_mode=request.processing_mode.value, + ) as sp: + try: + if category == FileCategory.UNSUPPORTED: + raise EtlUnsupportedFileError( + f"File type not supported for parsing: {request.filename}" + ) - if category == FileCategory.UNSUPPORTED: - raise EtlUnsupportedFileError( - f"File type not supported for parsing: {request.filename}" - ) + if category == FileCategory.PLAINTEXT: + content = read_plaintext(request.file_path) + result = EtlResult( + markdown_content=content, + etl_service="PLAINTEXT", + content_type="plaintext", + ) + return result - if category == FileCategory.PLAINTEXT: - content = read_plaintext(request.file_path) - return EtlResult( - markdown_content=content, - etl_service="PLAINTEXT", - content_type="plaintext", - ) + if category == FileCategory.DIRECT_CONVERT: + content = convert_file_directly(request.file_path, request.filename) + result = EtlResult( + markdown_content=content, + etl_service="DIRECT_CONVERT", + content_type="direct_convert", + ) + return result - if category == FileCategory.DIRECT_CONVERT: - content = convert_file_directly(request.file_path, request.filename) - return EtlResult( - markdown_content=content, - etl_service="DIRECT_CONVERT", - content_type="direct_convert", - ) + if category == FileCategory.AUDIO: + content = await transcribe_audio( + request.file_path, request.filename + ) + result = EtlResult( + markdown_content=content, + etl_service="AUDIO", + content_type="audio", + ) + return result - if category == FileCategory.AUDIO: - content = await transcribe_audio(request.file_path, request.filename) - return EtlResult( - markdown_content=content, - 
etl_service="AUDIO", - content_type="audio", - ) + if category == FileCategory.IMAGE: + result = await self._extract_image(request) + return result - if category == FileCategory.IMAGE: - return await self._extract_image(request) - - return await self._extract_document(request) + result = await self._extract_document(request) + return result + except Exception as exc: + status = "error" + error_category = ot_metrics.categorize_exception(exc) + raise + finally: + with contextlib.suppress(Exception): + if result is not None: + sp.set_attribute("etl.service", result.etl_service) + sp.set_attribute("content.type", result.content_type) + sp.set_attribute("etl.status", status) + ot_metrics.record_etl_extract_duration( + time.perf_counter() - start, + etl_service=result.etl_service if result else None, + content_type=result.content_type if result else category.value, + status=status, + ) + ot_metrics.record_etl_extract_outcome( + etl_service=result.etl_service if result else None, + content_type=result.content_type if result else category.value, + status=status, + error_category=error_category, + ) async def _extract_image(self, request: EtlRequest) -> EtlResult: if self._vision_llm: try: from app.etl_pipeline.parsers.vision_llm import parse_with_vision_llm - content = await parse_with_vision_llm( - request.file_path, request.filename, self._vision_llm - ) + with ot.etl_parse_span( + etl_service="VISION_LLM", + content_type="image", + file_extension=_file_extension(request.filename), + ) as sp: + content = await parse_with_vision_llm( + request.file_path, request.filename, self._vision_llm + ) + sp.set_attribute("etl.status", "success") return EtlResult( markdown_content=content, etl_service="VISION_LLM", @@ -87,14 +139,34 @@ class EtlPipelineService: request.filename, exc_info=True, ) + ot.add_event( + "etl.fallback", + { + "fallback.from": "vision_llm", + "fallback.to": "document_parser", + "fallback.reason": ot_metrics.categorize_exception(exc), + }, + ) else: 
logging.info( "No vision LLM provided, falling back to document parser for %s", request.filename, ) + ot.add_event( + "etl.fallback", + { + "fallback.from": "vision_llm", + "fallback.to": "document_parser", + "fallback.reason": "not_configured", + }, + ) try: - return await self._extract_document(request) + with ot.etl_ocr_span( + etl_service=app_config.ETL_SERVICE, + file_extension=_file_extension(request.filename), + ): + return await self._extract_document(request) except (EtlUnsupportedFileError, EtlServiceUnavailableError): raise EtlUnsupportedFileError( f"Cannot process image {request.filename}: vision LLM " @@ -121,18 +193,27 @@ class EtlPipelineService: f"File type {ext} is not supported by {etl_service}" ) - if etl_service == "DOCLING": - from app.etl_pipeline.parsers.docling import parse_with_docling + with ot.etl_parse_span( + etl_service=etl_service, + content_type="document", + file_extension=ext, + processing_mode=request.processing_mode.value, + ) as sp: + if etl_service == "DOCLING": + from app.etl_pipeline.parsers.docling import parse_with_docling - content = await parse_with_docling(request.file_path, request.filename) - elif etl_service == "UNSTRUCTURED": - from app.etl_pipeline.parsers.unstructured import parse_with_unstructured + content = await parse_with_docling(request.file_path, request.filename) + elif etl_service == "UNSTRUCTURED": + from app.etl_pipeline.parsers.unstructured import ( + parse_with_unstructured, + ) - content = await parse_with_unstructured(request.file_path) - elif etl_service == "LLAMACLOUD": - content = await self._extract_with_llamacloud(request) - else: - raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {etl_service}") + content = await parse_with_unstructured(request.file_path) + elif etl_service == "LLAMACLOUD": + content = await self._extract_with_llamacloud(request) + else: + raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {etl_service}") + sp.set_attribute("etl.status", "success") # When the 
operator opts into vision-LLM at ingest, walk the # original file's embedded images and append a structured @@ -171,9 +252,14 @@ class EtlPipelineService: async def _ocr_image(image_path: str, image_name: str) -> str: try: sub = EtlPipelineService(vision_llm=None) - ocr_result = await sub.extract( - EtlRequest(file_path=image_path, filename=image_name) - ) + with ot.etl_picture_ocr_span( + file_extension=_file_extension(image_name) + ) as sp: + ocr_result = await sub.extract( + EtlRequest(file_path=image_path, filename=image_name) + ) + sp.set_attribute("etl.service", ocr_result.etl_service) + sp.set_attribute("etl.status", "success") except ( EtlUnsupportedFileError, EtlServiceUnavailableError, @@ -181,20 +267,42 @@ class EtlPipelineService: # Common case: the configured ETL service can't OCR # this image format (or no service is configured at # all). Don't spam warnings -- just no OCR for it. + ot.add_event( + "etl.ocr.skipped", + { + "skip.reason": "unsupported_format", + "error.category": ot_metrics.categorize_exception(exc), + }, + ) logging.debug("Skipping per-image OCR for %s: %s", image_name, exc) return "" return ocr_result.markdown_content try: - result = await describe_pictures( - request.file_path, - request.filename, - self._vision_llm, - ocr_runner=_ocr_image, - ) - except Exception: + with ot.etl_picture_describe_span() as sp: + result = await describe_pictures( + request.file_path, + request.filename, + self._vision_llm, + ocr_runner=_ocr_image, + ) + sp.set_attribute("image.described.count", len(result.descriptions)) + sp.set_attribute("image.failed.count", result.failed) + sp.set_attribute("image.skipped.too_small", result.skipped_too_small) + sp.set_attribute("image.skipped.too_large", result.skipped_too_large) + sp.set_attribute("image.skipped.duplicate", result.skipped_duplicate) + sp.set_attribute("etl.status", "success") + except Exception as exc: # Picture description is additive; never let it fail an # otherwise-successful document 
extraction. + ot.add_event( + "etl.degraded", + { + "degraded.reason": "picture_describe_failed", + "degraded.action": "return_parser_output", + "error.category": ot_metrics.categorize_exception(exc), + }, + ) logging.warning( "Picture description failed for %s, returning parser output unchanged", request.filename, @@ -247,7 +355,15 @@ class EtlPipelineService: return await parse_with_azure_doc_intelligence( request.file_path, processing_mode=mode_value ) - except Exception: + except Exception as exc: + ot.add_event( + "etl.fallback", + { + "fallback.from": "azure_di", + "fallback.to": "llamacloud", + "fallback.reason": ot_metrics.categorize_exception(exc), + }, + ) logging.warning( "Azure Document Intelligence failed for %s, " "falling back to LlamaCloud", diff --git a/surfsense_backend/app/event_bus/__init__.py b/surfsense_backend/app/event_bus/__init__.py new file mode 100644 index 000000000..da5735fe6 --- /dev/null +++ b/surfsense_backend/app/event_bus/__init__.py @@ -0,0 +1,25 @@ +"""In-process domain event bus. + +Domain-agnostic pub/sub. Producers ``await bus.publish(...)``; subscribers +``bus.subscribe(...)``. Domain modules depend on it, never the reverse. + + from app.event_bus import bus + await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7) +""" + +from __future__ import annotations + +from . import events # noqa: F401 — populates the event-type catalog +from .bus import EventBus, Subscriber, bus +from .catalog import EventCatalog, EventType, catalog +from .event import Event + +__all__ = [ + "Event", + "EventBus", + "EventCatalog", + "EventType", + "Subscriber", + "bus", + "catalog", +] diff --git a/surfsense_backend/app/event_bus/bus.py b/surfsense_backend/app/event_bus/bus.py new file mode 100644 index 000000000..38c93ba7c --- /dev/null +++ b/surfsense_backend/app/event_bus/bus.py @@ -0,0 +1,77 @@ +"""In-process pub/sub. Streams :class:`Event` values from producers to listeners. 
+ +Boundary-crossing (Celery, DB, workers) is a subscriber's job — e.g. the +``event`` trigger enqueues its own task. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from .event import Event + +logger = logging.getLogger(__name__) + +Subscriber = Callable[[Event], Awaitable[None]] + + +class EventBus: + """An in-process pub/sub bus with a per-instance subscriber registry.""" + + def __init__(self) -> None: + self._subscribers: list[Subscriber] = [] + + def subscribe(self, handler: Subscriber) -> Subscriber: + """Register ``handler`` for every event. Idempotent; returns the handler + so it works as a decorator.""" + if handler not in self._subscribers: + self._subscribers.append(handler) + return handler + + def subscribers(self) -> list[Subscriber]: + """Defensive snapshot of the registered subscribers.""" + return list(self._subscribers) + + async def publish( + self, + event_type: str, + payload: dict[str, Any] | None = None, + *, + search_space_id: int, + ) -> None: + """Stamp an :class:`Event` and fan it out. Call after your commit.""" + event = Event( + event_type=event_type, + payload=payload or {}, + search_space_id=search_space_id, + ) + await self.dispatch(event) + + async def dispatch(self, event: Event) -> None: + """Fan ``event`` out concurrently. Subscriber failures are logged and + isolated; never propagate.""" + subscribers = self.subscribers() + if not subscribers: + return + + results = await asyncio.gather( + *(handler(event) for handler in subscribers), + return_exceptions=True, + ) + + for handler, result in zip(subscribers, results, strict=True): + if isinstance(result, Exception): + logger.error( + "event subscriber %r failed for event %s (%s)", + getattr(handler, "__qualname__", handler), + event.event_id, + event.event_type, + exc_info=result, + ) + + +# Process-wide bus. Producers publish to it; subscribers register on it. 
+bus = EventBus() diff --git a/surfsense_backend/app/event_bus/catalog.py b/surfsense_backend/app/event_bus/catalog.py new file mode 100644 index 000000000..a50be689f --- /dev/null +++ b/surfsense_backend/app/event_bus/catalog.py @@ -0,0 +1,48 @@ +"""Event type catalog: the deliberate contract behind each event. + +``EventType`` declares a dotted name and the shape of its payload. +``EventCatalog`` is the registry — populated once at import by each event type +module. ``catalog`` is the process-wide singleton. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + + +@dataclass(frozen=True, slots=True) +class EventType: + type: str + description: str + payload_model: type[BaseModel] + + @property + def payload_schema(self) -> dict[str, Any]: + """JSON Schema (draft 2020-12) derived from ``payload_model``.""" + return self.payload_model.model_json_schema() + + +class EventCatalog: + """Registry of known event types. Populated at import; read at runtime.""" + + def __init__(self) -> None: + self._registry: dict[str, EventType] = {} + + def register(self, event_type: EventType) -> None: + """Register an event type. Raises on duplicate type.""" + if event_type.type in self._registry: + raise ValueError(f"Event type already registered: {event_type.type!r}") + self._registry[event_type.type] = event_type + + def get(self, type_: str) -> EventType | None: + return self._registry.get(type_) + + def all(self) -> dict[str, EventType]: + """Defensive snapshot of the registry.""" + return dict(self._registry) + + +catalog = EventCatalog() diff --git a/surfsense_backend/app/event_bus/event.py b/surfsense_backend/app/event_bus/event.py new file mode 100644 index 000000000..5dc3f7081 --- /dev/null +++ b/surfsense_backend/app/event_bus/event.py @@ -0,0 +1,38 @@ +"""The ``Event`` value object — the only shape that crosses the bus. 
+ +An immutable fact: something named happened, with this payload, in this space, +at this time. JSON round-trippable so a subscriber can queue it to a worker. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +def _new_event_id() -> str: + return uuid.uuid4().hex + + +def _now() -> datetime: + return datetime.now(UTC) + + +class Event(BaseModel): + """A published domain fact. + + ``event_type`` is a dotted namespace (``document.indexed``, etc). ``payload`` is + JSON-serializable. ``search_space_id`` scopes delivery. ``event_id`` and + ``occurred_at`` are engine-stamped. + """ + + model_config = ConfigDict(frozen=True) + + event_type: str + payload: dict[str, Any] = Field(default_factory=dict) + search_space_id: int + event_id: str = Field(default_factory=_new_event_id) + occurred_at: datetime = Field(default_factory=_now) diff --git a/surfsense_backend/app/event_bus/events/__init__.py b/surfsense_backend/app/event_bus/events/__init__.py new file mode 100644 index 000000000..47c0e64c1 --- /dev/null +++ b/surfsense_backend/app/event_bus/events/__init__.py @@ -0,0 +1,5 @@ +"""Domain event type definitions — each in its own module, self-registering at import.""" + +from __future__ import annotations + +from . import document_entered_folder # noqa: F401 diff --git a/surfsense_backend/app/event_bus/events/document_entered_folder.py b/surfsense_backend/app/event_bus/events/document_entered_folder.py new file mode 100644 index 000000000..fc4e2de14 --- /dev/null +++ b/surfsense_backend/app/event_bus/events/document_entered_folder.py @@ -0,0 +1,86 @@ +"""``document.entered_folder``: a document became a member of a folder. + +Fires once per arrival, however the document got there (upload, AI sort, move). +The payload carries the fields a user can filter a trigger on. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, computed_field + +from app.event_bus.catalog import EventType, catalog + +EVENT_TYPE = "document.entered_folder" + + +class DocumentEnteredFolderPayload(BaseModel): + """Snapshot of the document at the moment it entered ``folder_id``. + + ``previous_folder_id`` is the folder it left, or ``None`` for a first + placement. ``is_move`` derives from it and is emitted for filtering. + """ + + model_config = ConfigDict(extra="forbid") + + document_id: int + folder_id: int + previous_folder_id: int | None = None + document_type: str + title: str + connector_id: int | None = None + created_by_id: str | None = None + + @computed_field + @property + def is_move(self) -> bool: + return self.previous_folder_id is not None + + +catalog.register( + EventType( + type=EVENT_TYPE, + description="A document became a member of a folder.", + payload_model=DocumentEnteredFolderPayload, + ) +) + + +def payload_if_entered_folder( + *, + document_id: int, + search_space_id: int, + new_folder_id: int | None, + previous_folder_id: int | None, + folder_id_changed: bool, + status_state: str, + document_type: str, + title: str, + connector_id: int | None, + created_by_id: str | None, +) -> dict | None: + """Return a publish payload if this commit represents a folder arrival, else None. + + ``folder_id_changed`` comes from SQLAlchemy attribute history — it is True + only when ``folder_id`` actually changed in this transaction, preventing + spurious events on unrelated saves. 
+ """ + if not folder_id_changed: + return None + if new_folder_id is None: + return None + if status_state != "ready": + return None + + return { + "event_type": EVENT_TYPE, + "search_space_id": search_space_id, + "payload": { + "document_id": document_id, + "folder_id": new_folder_id, + "previous_folder_id": previous_folder_id, + "document_type": document_type, + "title": title, + "connector_id": connector_id, + "created_by_id": created_by_id, + }, + } diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 2339647ea..282bd6034 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -2,6 +2,7 @@ import asyncio import contextlib import hashlib import logging +import sys import time from collections.abc import Awaitable, Callable from dataclasses import dataclass, field @@ -57,6 +58,7 @@ from app.indexing_pipeline.pipeline_logger import ( log_retryable_llm_error, log_unexpected_error, ) +from app.observability import metrics as ot_metrics, otel as ot from app.utils.perf import get_perf_logger @@ -362,6 +364,16 @@ class IndexingPipelineService: ) perf = get_perf_logger() t_index = time.perf_counter() + document_type = ( + document.document_type.value + if getattr(document, "document_type", None) + else None + ) + persist_span_cm = ot.kb_persist_span( + document_type=document_type, + ) + persist_span = persist_span_cm.__enter__() + outcome_status = "failed" try: log_index_started(ctx) document.status = DocumentStatus.processing() @@ -429,34 +441,41 @@ class IndexingPipelineService: time.perf_counter() - t_index, ) log_index_success(ctx, chunk_count=len(chunks)) + outcome_status = "success" await self._enqueue_ai_sort_if_enabled(document) except RETRYABLE_LLM_ERRORS as e: + ot.record_error(persist_span, e) log_retryable_llm_error(ctx, e) + outcome_status = 
"requeued" await rollback_and_persist_failure( self.session, document, llm_retryable_message(e) ) except PERMANENT_LLM_ERRORS as e: + ot.record_error(persist_span, e) log_permanent_llm_error(ctx, e) await rollback_and_persist_failure( self.session, document, llm_permanent_message(e) ) except RecursionError as e: + ot.record_error(persist_span, e) log_chunking_overflow(ctx, e) await rollback_and_persist_failure( self.session, document, PipelineMessages.CHUNKING_OVERFLOW ) except EMBEDDING_ERRORS as e: + ot.record_error(persist_span, e) log_embedding_error(ctx, e) await rollback_and_persist_failure( self.session, document, embedding_message(e) ) except Exception as e: + ot.record_error(persist_span, e) log_unexpected_error(ctx, e) await rollback_and_persist_failure( self.session, document, safe_exception_message(e) @@ -465,6 +484,17 @@ class IndexingPipelineService: with contextlib.suppress(Exception): await self.session.refresh(document) + with contextlib.suppress(Exception): + persist_span.set_attribute("indexing.status", outcome_status) + ot_metrics.record_indexing_document_duration( + time.perf_counter() - t_index, + document_type=document_type, + ) + ot_metrics.record_indexing_document_outcome( + document_type=document_type, + status=outcome_status, + ) + persist_span_cm.__exit__(*sys.exc_info()) return document async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: diff --git a/surfsense_backend/app/observability/__init__.py b/surfsense_backend/app/observability/__init__.py index dbf082561..a675b1dae 100644 --- a/surfsense_backend/app/observability/__init__.py +++ b/surfsense_backend/app/observability/__init__.py @@ -5,3 +5,5 @@ small wrapper around the optional ``opentelemetry`` instrumentation. The wrapper is a no-op when OTEL is not configured, so importing it from performance-critical paths is safe. 
""" + +__all__ = ["bootstrap", "metrics", "otel"] diff --git a/surfsense_backend/app/observability/bootstrap.py b/surfsense_backend/app/observability/bootstrap.py new file mode 100644 index 000000000..70008d43d --- /dev/null +++ b/surfsense_backend/app/observability/bootstrap.py @@ -0,0 +1,390 @@ +"""Programmatic OpenTelemetry bootstrap for SurfSense backend processes.""" + +from __future__ import annotations + +import contextlib +import logging +import os +import socket +from importlib import metadata +from typing import Any +from urllib.parse import urlsplit, urlunsplit + +from app.observability import otel + +logger = logging.getLogger(__name__) + +_BOOL_TRUE = {"1", "true", "yes", "on"} + +_TRACES_INITIALIZED = False +_METRICS_INITIALIZED = False +_LOGS_INITIALIZED = False +_FASTAPI_INSTRUMENTED = False +_SQLALCHEMY_INSTRUMENTED = False +_PSYCOPG_INSTRUMENTED = False +_REDIS_INSTRUMENTED = False +_HTTPX_INSTRUMENTED = False +_CELERY_INSTRUMENTED = False + +_TRACER_PROVIDER: Any | None = None +_METER_PROVIDER: Any | None = None + + +def _env_truthy(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in _BOOL_TRUE + + +def is_otel_disabled() -> bool: + """Return true when either SurfSense or OTel's spec kill switch is set.""" + return _env_truthy("SURFSENSE_DISABLE_OTEL") or _env_truthy("OTEL_SDK_DISABLED") + + +def is_otel_configured() -> bool: + """Return true when this process should export OTel signals.""" + return bool( + os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + or os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") + or os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT") + ) + + +def _package_version() -> str: + with contextlib.suppress(metadata.PackageNotFoundError): + return metadata.version("surf-new-backend") + return "unknown" + + +def _deployment_environment() -> str: + return os.environ.get("SURFSENSE_ENV", "dev") + + +def _build_resource(): + from opentelemetry.sdk.resources import Resource + + deployment_environment = 
_deployment_environment() + return Resource.create( + { + "service.name": os.environ.get("OTEL_SERVICE_NAME", "surfsense-backend"), + "service.version": _package_version(), + "service.instance.id": socket.gethostname(), + "deployment.environment.name": deployment_environment, + # Compatibility alias for Grafana onboarding checks that still use + # the older semantic-convention key. + "deployment.environment": deployment_environment, + } + ) + + +def _otlp_protocol() -> str: + return os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc").strip().lower() + + +def _trace_exporter(): + if _otlp_protocol() == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter, + ) + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") + return OTLPSpanExporter(endpoint=endpoint) if endpoint else OTLPSpanExporter() + + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") + return OTLPSpanExporter(endpoint=endpoint) if endpoint else OTLPSpanExporter() + + +def _metric_exporter(): + if _otlp_protocol() == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( + OTLPMetricExporter, + ) + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT") + return ( + OTLPMetricExporter(endpoint=endpoint) if endpoint else OTLPMetricExporter() + ) + + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter, + ) + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT") + return OTLPMetricExporter(endpoint=endpoint) if endpoint else OTLPMetricExporter() + + +def _safe_instrument(name: str, instrument: Any) -> bool: + try: + instrument() + except Exception: + logger.warning("OpenTelemetry %s instrumentation failed", name, exc_info=True) + return False + return True + + +def _url_without_query(raw_url: Any) -> str | None: + try: + parts = 
urlsplit(str(raw_url)) + except Exception: + return None + if not parts.scheme or not parts.netloc: + return None + return urlunsplit((parts.scheme, parts.netloc, parts.path or "/", "", "")) + + +def _sanitize_http_span_url(span: Any, request: Any) -> None: + sanitized = _url_without_query(getattr(request, "url", None)) + if not sanitized: + return + with contextlib.suppress(Exception): + # Keep both old and current semantic-convention names safe. The + # collector can drop one later without needing application changes. + span.set_attribute("http.url", sanitized) + span.set_attribute("url.full", sanitized) + + +def _instrument_fastapi(app: Any | None) -> None: + global _FASTAPI_INSTRUMENTED + if app is None or _FASTAPI_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app( + app, + excluded_urls="/health,/ready,/metrics", + ) + + if _safe_instrument("FastAPI", _run): + _FASTAPI_INSTRUMENTED = True + + +def instrument_sqlalchemy_engine(engine: Any) -> None: + """Instrument a SQLAlchemy engine once per process.""" + global _SQLALCHEMY_INSTRUMENTED + if _SQLALCHEMY_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor + + SQLAlchemyInstrumentor().instrument( + engine=getattr(engine, "sync_engine", engine), + enable_commenter=True, + ) + + if _safe_instrument("SQLAlchemy", _run): + _SQLALCHEMY_INSTRUMENTED = True + + +def _instrument_sqlalchemy() -> None: + if _SQLALCHEMY_INSTRUMENTED: + return + with contextlib.suppress(Exception): + from app.db import engine + + instrument_sqlalchemy_engine(engine) + + +def _instrument_psycopg() -> None: + global _PSYCOPG_INSTRUMENTED + if _PSYCOPG_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.psycopg import PsycopgInstrumentor + + PsycopgInstrumentor().instrument() + + if _safe_instrument("psycopg", _run): + 
_PSYCOPG_INSTRUMENTED = True + + +def _instrument_redis() -> None: + global _REDIS_INSTRUMENTED + if _REDIS_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.redis import RedisInstrumentor + + RedisInstrumentor().instrument() + + if _safe_instrument("Redis", _run): + _REDIS_INSTRUMENTED = True + + +def _instrument_httpx() -> None: + global _HTTPX_INSTRUMENTED + if _HTTPX_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + HTTPXClientInstrumentor().instrument( + request_hook=lambda span, request: _sanitize_http_span_url(span, request), + response_hook=lambda span, request, _response: _sanitize_http_span_url( + span, request + ), + ) + + if _safe_instrument("HTTPX", _run): + _HTTPX_INSTRUMENTED = True + + +def instrument_celery() -> None: + """Instrument Celery producer/consumer hooks once per process.""" + global _CELERY_INSTRUMENTED + if _CELERY_INSTRUMENTED: + return + + def _run() -> None: + from opentelemetry.instrumentation.celery import CeleryInstrumentor + + CeleryInstrumentor().instrument() + + if _safe_instrument("Celery", _run): + _CELERY_INSTRUMENTED = True + + +def _instrument_libraries(app: Any | None) -> None: + _instrument_fastapi(app) + _instrument_sqlalchemy() + _instrument_psycopg() + _instrument_redis() + _instrument_httpx() + instrument_celery() + + +def init_traces(app: Any | None = None) -> None: + """Install the tracer provider, span processor, exporter, and instrumentors.""" + global _TRACER_PROVIDER, _TRACES_INITIALIZED + if _TRACES_INITIALIZED: + _instrument_fastapi(app) + return + + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.sdk.trace.sampling import ALWAYS_ON, ParentBased + + provider = TracerProvider( + resource=_build_resource(), + sampler=ParentBased(ALWAYS_ON), + ) + 
provider.add_span_processor(BatchSpanProcessor(_trace_exporter())) + + try: + trace.set_tracer_provider(provider) + except Exception: + logger.warning( + "OpenTelemetry tracer provider was already set; reusing existing provider", + exc_info=True, + ) + _TRACER_PROVIDER = trace.get_tracer_provider() + else: + _TRACER_PROVIDER = provider + + _TRACES_INITIALIZED = True + otel.reload_for_tests() + _instrument_libraries(app) + + +def init_metrics() -> None: + """Install the meter provider, metric reader, exporter, and custom gauges.""" + global _METER_PROVIDER, _METRICS_INITIALIZED + if _METRICS_INITIALIZED: + return + + from opentelemetry import metrics + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + + interval_ms = int(os.environ.get("OTEL_METRIC_EXPORT_INTERVAL", "60000")) + reader = PeriodicExportingMetricReader( + _metric_exporter(), + export_interval_millis=interval_ms, + ) + provider = MeterProvider(metric_readers=[reader], resource=_build_resource()) + + try: + metrics.set_meter_provider(provider) + except Exception: + logger.warning( + "OpenTelemetry meter provider was already set; reusing existing provider", + exc_info=True, + ) + _METER_PROVIDER = metrics.get_meter_provider() + else: + _METER_PROVIDER = provider + + _METRICS_INITIALIZED = True + from app.observability.metrics import register_runtime_observables + + register_runtime_observables() + + +def init_logs() -> None: + """Enable trace/span correlation fields on stdlib LogRecords.""" + global _LOGS_INITIALIZED + if _LOGS_INITIALIZED: + return + + def _run() -> None: + from opentelemetry.instrumentation.logging import LoggingInstrumentor + + # Required for stdlib LogRecords to receive otelTraceID/otelSpanID. + # logging.basicConfig is already installed by main.py, so this does not + # take over formatting in normal app startup. 
+ LoggingInstrumentor().instrument(set_logging_format=True) + + if _safe_instrument("logging", _run): + _LOGS_INITIALIZED = True + + +def init_otel( + app: Any | None = None, + *, + traces: bool = True, + metrics: bool = True, + logs: bool = True, +) -> None: + """Initialize OpenTelemetry for a FastAPI or Celery process.""" + if is_otel_disabled() or not is_otel_configured(): + otel.reload_for_tests() + return + + if traces: + init_traces(app) + if metrics: + init_metrics() + if logs: + init_logs() + + +def shutdown_otel(timeout_millis: int = 5000) -> None: + """Best-effort flush and shutdown for installed providers.""" + for provider in (_TRACER_PROVIDER, _METER_PROVIDER): + if provider is None: + continue + with contextlib.suppress(Exception): + provider.force_flush(timeout_millis=timeout_millis) + with contextlib.suppress(Exception): + provider.shutdown() + + +__all__ = [ + "_BOOL_TRUE", + "_build_resource", + "init_logs", + "init_metrics", + "init_otel", + "init_traces", + "instrument_celery", + "instrument_sqlalchemy_engine", + "is_otel_configured", + "is_otel_disabled", + "shutdown_otel", +] diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py new file mode 100644 index 000000000..798a6e2f7 --- /dev/null +++ b/surfsense_backend/app/observability/metrics.py @@ -0,0 +1,684 @@ +"""Custom OpenTelemetry metrics for SurfSense. + +This module owns all SurfSense-specific metric instruments. Callers use the +small helper functions below instead of constructing instruments directly so +attribute names and cardinality stay consistent across the backend. 
+""" + +from __future__ import annotations + +import contextlib +import gc +import logging +from functools import lru_cache +from importlib import metadata +from typing import Any + +from app.observability import otel + +logger = logging.getLogger(__name__) + +_INSTRUMENTATION_NAME = "surfsense.platform" +_OBSERVABLES_REGISTERED = False +_ERROR_CATEGORY_UNKNOWN = "unknown" + +_ERROR_CATEGORY_HINTS: tuple[tuple[str, tuple[str, ...]], ...] = ( + ("rate_limited", ("ratelimit", "rate_limit", "toomanyrequests", "429")), + ("auth_failed", ("authentication", "auth", "unauthorized", "forbidden")), + ("quota_exhausted", ("quota", "insufficient", "credit", "billing")), + ("timeout", ("timeout", "timedout", "deadline")), + ("network_failed", ("connection", "connect", "network", "dns", "socket")), + ("server_error", ("internalserver", "serviceunavailable", "badgateway", "gateway")), + ("lock_contention", ("lock", "busy", "contention", "alreadyrunning")), + ("unsupported_format", ("unsupported", "format", "filetype")), + ("provider_error", ("provider", "apierror", "apistatus", "badrequest")), +) + + +def _package_version() -> str: + with contextlib.suppress(metadata.PackageNotFoundError): + return metadata.version("surf-new-backend") + return "unknown" + + +def _is_enabled() -> bool: + return otel.is_enabled() + + +def _clean_attrs(attrs: dict[str, Any]) -> dict[str, str | int | float | bool]: + """Drop empty values and coerce low-cardinality attrs to OTel-safe scalars.""" + cleaned: dict[str, str | int | float | bool] = {} + for key, value in attrs.items(): + if value is None: + continue + if isinstance(value, bool | int | float): + cleaned[key] = value + continue + text = str(value) + if text: + cleaned[key] = text + return cleaned + + +def _attrs_with_optional_error_category( + attrs: dict[str, Any], error_category: str | None +) -> dict[str, Any]: + if error_category: + return {**attrs, "error.category": error_category} + return attrs + + +def categorize_exception(exc: 
BaseException | None) -> str: + """Return a low-cardinality category for an exception.""" + if exc is None: + return _ERROR_CATEGORY_UNKNOWN + haystack = " ".join( + cls.__name__.replace("-", "").replace("_", "").lower() + for cls in type(exc).__mro__ + ) + for category, hints in _ERROR_CATEGORY_HINTS: + if any(hint in haystack for hint in hints): + return category + return _ERROR_CATEGORY_UNKNOWN + + +def parse_celery_task_label(task_name: str | None) -> str: + """Return the operation token from a Celery task name.""" + if not task_name: + return "unknown" + operation = str(task_name).split("_", 1)[0].strip() + return operation or "unknown" + + +def _record(callable_obj: Any, value: int | float, attrs: dict[str, Any]) -> None: + if not _is_enabled(): + return + with contextlib.suppress(Exception): + callable_obj.record(value, _clean_attrs(attrs)) + + +def _add(callable_obj: Any, value: int, attrs: dict[str, Any]) -> None: + if not _is_enabled(): + return + with contextlib.suppress(Exception): + callable_obj.add(value, _clean_attrs(attrs)) + + +@lru_cache(maxsize=1) +def _get_meter(): + from opentelemetry import metrics + + return metrics.get_meter(_INSTRUMENTATION_NAME, _package_version()) + + +@lru_cache(maxsize=1) +def _model_call_duration(): + return _get_meter().create_histogram( + "surfsense.model.call.duration", + unit="ms", + description="Duration of SurfSense LLM model calls.", + ) + + +@lru_cache(maxsize=1) +def _model_token_usage(): + return _get_meter().create_histogram( + "gen_ai.client.token.usage", + unit="{token}", + description="Token usage reported by GenAI model responses.", + ) + + +@lru_cache(maxsize=1) +def _tool_call_duration(): + return _get_meter().create_histogram( + "surfsense.tool.call.duration", + unit="ms", + description="Duration of SurfSense agent tool calls.", + ) + + +@lru_cache(maxsize=1) +def _tool_call_errors(): + return _get_meter().create_counter( + "surfsense.tool.call.errors", + description="Count of SurfSense agent tool 
call errors.", + ) + + +@lru_cache(maxsize=1) +def _kb_search_duration(): + return _get_meter().create_histogram( + "surfsense.kb.search.duration", + unit="ms", + description="Duration of SurfSense knowledge-base search calls.", + ) + + +@lru_cache(maxsize=1) +def _compaction_runs(): + return _get_meter().create_counter( + "surfsense.compaction.runs", + description="Count of SurfSense conversation compaction runs.", + ) + + +@lru_cache(maxsize=1) +def _permission_asks(): + return _get_meter().create_counter( + "surfsense.permission.asks", + description="Count of SurfSense permission asks.", + ) + + +@lru_cache(maxsize=1) +def _interrupts(): + return _get_meter().create_counter( + "surfsense.interrupt.raised", + description="Count of SurfSense interrupts raised.", + ) + + +@lru_cache(maxsize=1) +def _indexing_document_duration(): + return _get_meter().create_histogram( + "surfsense.indexing.document.duration", + unit="s", + description="Duration of SurfSense document indexing.", + ) + + +@lru_cache(maxsize=1) +def _indexing_document_outcome(): + return _get_meter().create_counter( + "surfsense.indexing.document.outcome", + description="Count of SurfSense document indexing outcomes.", + ) + + +@lru_cache(maxsize=1) +def _connector_sync_duration(): + return _get_meter().create_histogram( + "surfsense.connector.sync.duration", + unit="s", + description="Duration of SurfSense connector sync tasks.", + ) + + +@lru_cache(maxsize=1) +def _connector_sync_outcome(): + return _get_meter().create_counter( + "surfsense.connector.sync.outcome", + description="Count of SurfSense connector sync outcomes.", + ) + + +@lru_cache(maxsize=1) +def _auth_failures(): + return _get_meter().create_counter( + "surfsense.auth.failures", + description="Count of SurfSense authentication failures.", + ) + + +@lru_cache(maxsize=1) +def _rate_limit_rejections(): + return _get_meter().create_counter( + "surfsense.rate_limit.rejections", + description="Count of SurfSense rate-limit rejections.", + ) 
+ + +@lru_cache(maxsize=1) +def _perf_elapsed(): + return _get_meter().create_histogram( + "surfsense.perf.elapsed_ms", + unit="ms", + description="Elapsed time recorded by SurfSense perf timers.", + ) + + +@lru_cache(maxsize=1) +def _chat_request_duration(): + return _get_meter().create_histogram( + "surfsense.chat.request.duration", + unit="ms", + description="Duration of SurfSense streamed chat requests.", + ) + + +@lru_cache(maxsize=1) +def _chat_request_outcome(): + return _get_meter().create_counter( + "surfsense.chat.request.outcome", + description="Count of SurfSense chat request outcomes.", + ) + + +@lru_cache(maxsize=1) +def _subagent_invoke_duration(): + return _get_meter().create_histogram( + "surfsense.subagent.invoke.duration", + unit="ms", + description="Duration of SurfSense subagent invocations.", + ) + + +@lru_cache(maxsize=1) +def _subagent_invoke_outcome(): + return _get_meter().create_counter( + "surfsense.subagent.invoke.outcome", + description="Count of SurfSense subagent invocation outcomes.", + ) + + +@lru_cache(maxsize=1) +def _etl_extract_duration(): + return _get_meter().create_histogram( + "surfsense.etl.extract.duration", + unit="s", + description="Duration of SurfSense ETL extraction.", + ) + + +@lru_cache(maxsize=1) +def _etl_extract_outcome(): + return _get_meter().create_counter( + "surfsense.etl.extract.outcome", + description="Count of SurfSense ETL extraction outcomes.", + ) + + +@lru_cache(maxsize=1) +def _celery_heartbeat_refreshes(): + return _get_meter().create_counter( + "surfsense.celery.heartbeat.refreshes", + description="Count of SurfSense Celery heartbeat refreshes.", + ) + + +@lru_cache(maxsize=1) +def _celery_heartbeat_failures(): + return _get_meter().create_counter( + "surfsense.celery.heartbeat.failures", + description="Count of SurfSense Celery heartbeat failures.", + ) + + +@lru_cache(maxsize=1) +def _celery_queue_latency(): + return _get_meter().create_histogram( + "surfsense.celery.queue.latency", + unit="s", 
+ description="Time SurfSense Celery tasks spend waiting in queue.", + ) + + +def record_model_call_duration( + duration_ms: float, *, model: str | None, provider: str | None +) -> None: + _record( + _model_call_duration(), + duration_ms, + { + "gen_ai.request.model": model, + "gen_ai.provider.name": provider, + }, + ) + + +def record_model_token_usage( + *, + input_tokens: int | None, + output_tokens: int | None, + model: str | None, + provider: str | None, +) -> None: + base = { + "gen_ai.request.model": model, + "gen_ai.provider.name": provider, + "gen_ai.operation.name": "chat", + } + if input_tokens is not None: + _record( + _model_token_usage(), + int(input_tokens), + {**base, "gen_ai.token.type": "input"}, + ) + if output_tokens is not None: + _record( + _model_token_usage(), + int(output_tokens), + {**base, "gen_ai.token.type": "output"}, + ) + + +def record_tool_call_duration(duration_ms: float, *, tool_name: str) -> None: + _record(_tool_call_duration(), duration_ms, {"tool.name": tool_name}) + + +def record_tool_call_error(*, tool_name: str) -> None: + _add(_tool_call_errors(), 1, {"tool.name": tool_name}) + + +def record_kb_search_duration( + duration_ms: float, *, search_space_id: int | None, surface: str +) -> None: + _record( + _kb_search_duration(), + duration_ms, + {"search_space.id": search_space_id, "search.surface": surface}, + ) + + +def record_compaction_run(*, reason: str | None) -> None: + _add(_compaction_runs(), 1, {"compaction.reason": reason or "unknown"}) + + +def record_permission_ask(*, permission: str) -> None: + _add(_permission_asks(), 1, {"permission.permission": permission}) + + +def record_interrupt(*, interrupt_type: str) -> None: + _add(_interrupts(), 1, {"interrupt.type": interrupt_type}) + + +def record_indexing_document_duration( + duration_s: float, *, document_type: str | None +) -> None: + _record( + _indexing_document_duration(), + duration_s, + {"document.type": document_type or "unknown"}, + ) + + +def 
record_indexing_document_outcome(*, document_type: str | None, status: str) -> None: + _add( + _indexing_document_outcome(), + 1, + {"document.type": document_type or "unknown", "status": status}, + ) + + +def record_connector_sync_duration( + duration_s: float, *, connector_type: str | None +) -> None: + _record( + _connector_sync_duration(), + duration_s, + {"connector.type": connector_type or "unknown"}, + ) + + +def record_connector_sync_outcome( + *, connector_type: str | None, status: str, error_category: str | None = None +) -> None: + _add( + _connector_sync_outcome(), + 1, + _attrs_with_optional_error_category( + {"connector.type": connector_type or "unknown", "status": status}, + error_category, + ), + ) + + +def record_auth_failure(*, reason: str) -> None: + _add(_auth_failures(), 1, {"reason": reason}) + + +def record_rate_limit_rejection(*, scope: str) -> None: + _add(_rate_limit_rejections(), 1, {"scope": scope}) + + +def record_perf_elapsed(duration_ms: float, *, label: str) -> None: + _record(_perf_elapsed(), duration_ms, {"label": label}) + + +def record_chat_request_duration( + duration_ms: float, + *, + flow: str, + outcome: str, + agent_mode: str | None = None, +) -> None: + _record( + _chat_request_duration(), + duration_ms, + {"chat.flow": flow, "outcome": outcome, "agent.mode": agent_mode}, + ) + + +def record_chat_request_outcome( + *, + flow: str, + outcome: str, + agent_mode: str | None = None, + error_category: str | None = None, +) -> None: + _add( + _chat_request_outcome(), + 1, + _attrs_with_optional_error_category( + {"chat.flow": flow, "outcome": outcome, "agent.mode": agent_mode}, + error_category, + ), + ) + + +def record_subagent_invoke_duration( + duration_ms: float, + *, + subagent_type: str, + path: str | None, + outcome: str, +) -> None: + _record( + _subagent_invoke_duration(), + duration_ms, + { + "subagent.type": subagent_type, + "subagent.path": path or "unknown", + "outcome": outcome, + }, + ) + + +def 
record_subagent_invoke_outcome( + *, + subagent_type: str, + path: str | None, + outcome: str, +) -> None: + _add( + _subagent_invoke_outcome(), + 1, + { + "subagent.type": subagent_type, + "subagent.path": path or "unknown", + "outcome": outcome, + }, + ) + + +def record_etl_extract_duration( + duration_s: float, + *, + etl_service: str | None, + content_type: str | None, + status: str, +) -> None: + _record( + _etl_extract_duration(), + duration_s, + { + "etl.service": etl_service or "unknown", + "content.type": content_type or "unknown", + "status": status, + }, + ) + + +def record_etl_extract_outcome( + *, + etl_service: str | None, + content_type: str | None, + status: str, + error_category: str | None = None, +) -> None: + _add( + _etl_extract_outcome(), + 1, + _attrs_with_optional_error_category( + { + "etl.service": etl_service or "unknown", + "content.type": content_type or "unknown", + "status": status, + }, + error_category, + ), + ) + + +def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: + _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) + + +def record_celery_heartbeat_failure(*, heartbeat_type: str) -> None: + _add(_celery_heartbeat_failures(), 1, {"heartbeat.type": heartbeat_type}) + + +def record_celery_queue_latency( + duration_s: float, + *, + task_name: str | None, + queue: str | None, + scheduled: bool, + operation: str | None, +) -> None: + _record( + _celery_queue_latency(), + duration_s, + { + "task.name": task_name or "unknown", + "task.queue": queue or "unknown", + "task.scheduled": bool(scheduled), + "operation": operation or "unknown", + }, + ) + + +def _runtime_snapshot_value(key: str, transform: Any = None) -> list[Any]: + from opentelemetry.metrics import Observation + + from app.utils.perf import system_snapshot + + snap = system_snapshot() + value = snap.get(key) + if not isinstance(value, int | float) or value < 0: + return [] + if transform is not None: + value = transform(value) + return 
[Observation(value)] + + +def _observe_gc_collections(_options: Any) -> list[Any]: + from opentelemetry.metrics import Observation + + return [ + Observation(count, {"generation": str(generation)}) + for generation, count in enumerate(gc.get_count()) + ] + + +def register_runtime_observables() -> None: + """Register process/runtime observable gauges once per process.""" + global _OBSERVABLES_REGISTERED + if _OBSERVABLES_REGISTERED or not _is_enabled(): + return + + meter = _get_meter() + try: + # Each callback returns the value for a single gauge except GC, whose + # callback carries a generation attribute. + meter.create_observable_gauge( + "process.runtime.cpython.memory.rss", + callbacks=[ + lambda _options: _runtime_snapshot_value( + "rss_mb", lambda v: float(v) * 1024 * 1024 + ) + ], + unit="By", + description="Resident set size of the SurfSense backend process.", + ) + meter.create_observable_gauge( + "process.runtime.cpython.cpu.utilization", + callbacks=[ + lambda _options: _runtime_snapshot_value( + "cpu_percent", lambda v: float(v) / 100.0 + ) + ], + unit="1", + description="CPU utilization of the SurfSense backend process.", + ) + meter.create_observable_gauge( + "process.runtime.cpython.threads", + callbacks=[lambda _options: _runtime_snapshot_value("threads")], + unit="{thread}", + description="Thread count of the SurfSense backend process.", + ) + meter.create_observable_gauge( + "process.runtime.cpython.open_fds", + callbacks=[lambda _options: _runtime_snapshot_value("open_fds")], + unit="{fd}", + description="Open file descriptor count of the SurfSense backend process.", + ) + meter.create_observable_gauge( + "python.asyncio.tasks", + callbacks=[lambda _options: _runtime_snapshot_value("asyncio_tasks")], + unit="{task}", + description="Live asyncio task count in the current event loop.", + ) + meter.create_observable_gauge( + "process.runtime.cpython.gc.collections", + callbacks=[_observe_gc_collections], + unit="{collection}", + 
description="CPython GC counters by generation.", + ) + except Exception: + logger.warning("Failed to register OTel runtime observables", exc_info=True) + return + + _OBSERVABLES_REGISTERED = True + + +__all__ = [ + "categorize_exception", + "parse_celery_task_label", + "record_auth_failure", + "record_celery_heartbeat_failure", + "record_celery_heartbeat_refresh", + "record_celery_queue_latency", + "record_chat_request_duration", + "record_chat_request_outcome", + "record_compaction_run", + "record_connector_sync_duration", + "record_connector_sync_outcome", + "record_etl_extract_duration", + "record_etl_extract_outcome", + "record_indexing_document_duration", + "record_indexing_document_outcome", + "record_interrupt", + "record_kb_search_duration", + "record_model_call_duration", + "record_model_token_usage", + "record_perf_elapsed", + "record_permission_ask", + "record_rate_limit_rejection", + "record_subagent_invoke_duration", + "record_subagent_invoke_outcome", + "record_tool_call_duration", + "record_tool_call_error", + "register_runtime_observables", +] diff --git a/surfsense_backend/app/observability/otel.py b/surfsense_backend/app/observability/otel.py index 6791ab499..ad2178f39 100644 --- a/surfsense_backend/app/observability/otel.py +++ b/surfsense_backend/app/observability/otel.py @@ -66,6 +66,8 @@ def _resolve_enabled() -> bool: # Honor an explicit kill-switch first. if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}: return False + if os.environ.get("OTEL_SDK_DISABLED", "").lower() in {"1", "true", "yes", "on"}: + return False # Treat a configured endpoint as the canonical "OTel is wired up" signal. 
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): return True @@ -90,6 +92,48 @@ def is_enabled() -> bool: return _ENABLED +def _clean_event_attrs(attrs: dict[str, Any]) -> dict[str, str | int | float | bool]: + """Coerce event attributes to OTel-safe scalar values.""" + cleaned: dict[str, str | int | float | bool] = {} + for key, value in attrs.items(): + if value is None: + continue + if isinstance(value, bool | int | float): + cleaned[key] = value + continue + text = str(value) + if text: + cleaned[key] = text + return cleaned + + +def add_event(name: str, attributes: dict[str, Any] | None = None) -> None: + """Attach an event to the current active span. + + This is intentionally no-op and exception-safe when OTel is disabled, + unavailable, or no span is currently recording. + """ + if not _ENABLED or _ot_trace is None: + return + with contextlib.suppress(Exception): + sp = _ot_trace.get_current_span() + if sp is None or not sp.is_recording(): + return + sp.add_event( + name, + attributes=_clean_event_attrs(attributes) if attributes else None, + ) + + +def record_error(span_obj: Any, exc: BaseException) -> None: + """Record an exception and mark a span as errored without re-raising.""" + if not _ENABLED: + return + with contextlib.suppress(Exception): + span_obj.record_exception(exc) + span_obj.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc))) + + def _get_tracer(): if not _OTEL_AVAILABLE: return None @@ -198,8 +242,11 @@ def model_call_span( attrs: dict[str, Any] = {} if model_id: attrs["model.id"] = model_id + attrs["gen_ai.request.model"] = model_id if provider: attrs["model.provider"] = provider + attrs["gen_ai.provider.name"] = provider + attrs["gen_ai.operation.name"] = "chat" if extra: attrs.update(extra) return span("model.call", attributes=attrs) @@ -239,6 +286,152 @@ def kb_persist_span( return span("kb.persist", attributes=attrs) +def chat_request_span( + *, + chat_id: int | None = None, + search_space_id: int | None = None, + flow: str | None = 
None, + request_id: str | None = None, + turn_id: str | None = None, + filesystem_mode: str | None = None, + client_platform: str | None = None, + agent_mode: str | None = None, + extra: dict[str, Any] | None = None, +): + """Parent span for a single streamed chat or resume turn.""" + attrs: dict[str, Any] = {} + if chat_id is not None: + attrs["chat.id"] = int(chat_id) + if search_space_id is not None: + attrs["search_space.id"] = int(search_space_id) + if flow: + attrs["chat.flow"] = flow + if request_id: + attrs["request.id"] = request_id + if turn_id: + attrs["turn.id"] = turn_id + if filesystem_mode: + attrs["filesystem.mode"] = filesystem_mode + if client_platform: + attrs["client.platform"] = client_platform + if agent_mode: + attrs["agent.mode"] = agent_mode + if extra: + attrs.update(extra) + return span("chat.request", attributes=attrs) + + +def subagent_invoke_span( + *, + subagent_type: str, + path: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around invoking a delegated subagent from the main agent.""" + attrs: dict[str, Any] = {"subagent.type": subagent_type} + if path: + attrs["subagent.path"] = path + if extra: + attrs.update(extra) + return span("subagent.invoke", attributes=attrs) + + +def connector_sync_span( + *, + connector_type: str | None, + extra: dict[str, Any] | None = None, +): + """Business-level span around connector indexing task execution.""" + attrs: dict[str, Any] = {"connector.type": connector_type or "unknown"} + if extra: + attrs.update(extra) + return span("connector.sync", attributes=attrs) + + +def etl_extract_span( + *, + content_type: str | None = None, + file_extension: str | None = None, + processing_mode: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around top-level ETL extraction for a file.""" + attrs: dict[str, Any] = {} + if content_type: + attrs["content.type"] = content_type + if file_extension: + attrs["file.extension"] = file_extension + if processing_mode: + 
attrs["processing.mode"] = processing_mode + if extra: + attrs.update(extra) + return span("etl.extract", attributes=attrs) + + +def etl_parse_span( + *, + etl_service: str | None, + content_type: str | None = None, + file_extension: str | None = None, + processing_mode: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around a concrete ETL parser/backend call.""" + attrs: dict[str, Any] = {"etl.service": etl_service or "unknown"} + if content_type: + attrs["content.type"] = content_type + if file_extension: + attrs["file.extension"] = file_extension + if processing_mode: + attrs["processing.mode"] = processing_mode + if extra: + attrs.update(extra) + return span("etl.parse", attributes=attrs) + + +def etl_ocr_span( + *, + etl_service: str | None, + file_extension: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around OCR extraction from image content.""" + attrs: dict[str, Any] = {"etl.service": etl_service or "unknown"} + if file_extension: + attrs["file.extension"] = file_extension + if extra: + attrs.update(extra) + return span("etl.ocr", attributes=attrs) + + +def etl_picture_describe_span( + *, + image_count: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around describing embedded images in a document.""" + attrs: dict[str, Any] = {} + if image_count is not None: + attrs["image.count"] = int(image_count) + if extra: + attrs.update(extra) + return span("etl.picture.describe", attributes=attrs) + + +def etl_picture_ocr_span( + *, + file_extension: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around per-image OCR during picture description.""" + attrs: dict[str, Any] = {} + if file_extension: + attrs["file.extension"] = file_extension + if extra: + attrs.update(extra) + return span("etl.picture.ocr", attributes=attrs) + + def compaction_span( *, reason: str | None = None, @@ -301,14 +494,24 @@ def reload_for_tests() -> bool: __all__ = [ + "add_event", + 
"chat_request_span", "compaction_span", + "connector_sync_span", + "etl_extract_span", + "etl_ocr_span", + "etl_parse_span", + "etl_picture_describe_span", + "etl_picture_ocr_span", "interrupt_span", "is_enabled", "kb_persist_span", "kb_search_span", "model_call_span", "permission_asked_span", + "record_error", "reload_for_tests", "span", + "subagent_invoke_span", "tool_call_span", ] diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index e32c6c43d..47f7fe6b1 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -1,13 +1,51 @@ import asyncio import contextlib +import functools import time from datetime import datetime +from app.observability import metrics as ot_metrics, otel as ot from app.utils.perf import get_perf_logger _MAX_FETCH_CHUNKS_PER_DOC = 20 +def _instrument_search(mode: str): + def _decorator(func): + @functools.wraps(func) + async def _wrapper( + self, query_text: str, top_k: int, search_space_id: int, *args, **kwargs + ): + t0 = time.perf_counter() + with ot.kb_search_span( + search_space_id=search_space_id, + query_chars=len(query_text), + extra={"search.surface": "chunks", "search.mode": mode}, + ) as sp: + try: + result = await func( + self, query_text, top_k, search_space_id, *args, **kwargs + ) + except Exception: + ot_metrics.record_kb_search_duration( + (time.perf_counter() - t0) * 1000, + search_space_id=search_space_id, + surface="chunks", + ) + raise + sp.set_attribute("result.count", len(result)) + ot_metrics.record_kb_search_duration( + (time.perf_counter() - t0) * 1000, + search_space_id=search_space_id, + surface="chunks", + ) + return result + + return _wrapper + + return _decorator + + class ChucksHybridSearchRetriever: def __init__(self, db_session): """ @@ -18,6 +56,7 @@ class ChucksHybridSearchRetriever: """ self.db_session = db_session + @_instrument_search("vector") async 
def vector_search( self, query_text: str, @@ -88,6 +127,7 @@ class ChucksHybridSearchRetriever: return chunks + @_instrument_search("full_text") async def full_text_search( self, query_text: str, @@ -153,6 +193,7 @@ class ChucksHybridSearchRetriever: return chunks + @_instrument_search("hybrid") async def hybrid_search( self, query_text: str, diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index 3eabdb004..9ce86d404 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -1,12 +1,50 @@ import contextlib +import functools import time from datetime import datetime +from app.observability import metrics as ot_metrics, otel as ot from app.utils.perf import get_perf_logger _MAX_FETCH_CHUNKS_PER_DOC = 20 +def _instrument_search(mode: str): + def _decorator(func): + @functools.wraps(func) + async def _wrapper( + self, query_text: str, top_k: int, search_space_id: int, *args, **kwargs + ): + t0 = time.perf_counter() + with ot.kb_search_span( + search_space_id=search_space_id, + query_chars=len(query_text), + extra={"search.surface": "documents", "search.mode": mode}, + ) as sp: + try: + result = await func( + self, query_text, top_k, search_space_id, *args, **kwargs + ) + except Exception: + ot_metrics.record_kb_search_duration( + (time.perf_counter() - t0) * 1000, + search_space_id=search_space_id, + surface="documents", + ) + raise + sp.set_attribute("result.count", len(result)) + ot_metrics.record_kb_search_duration( + (time.perf_counter() - t0) * 1000, + search_space_id=search_space_id, + surface="documents", + ) + return result + + return _wrapper + + return _decorator + + class DocumentHybridSearchRetriever: def __init__(self, db_session): """ @@ -17,6 +55,7 @@ class DocumentHybridSearchRetriever: """ self.db_session = db_session + @_instrument_search("vector") async def vector_search( self, 
query_text: str, @@ -81,6 +120,7 @@ class DocumentHybridSearchRetriever: return documents + @_instrument_search("full_text") async def full_text_search( self, query_text: str, @@ -145,6 +185,7 @@ class DocumentHybridSearchRetriever: return documents + @_instrument_search("hybrid") async def hybrid_search( self, query_text: str, diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index ec4d1650f..8373f13c3 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -1,5 +1,7 @@ from fastapi import APIRouter +from app.automations.api import router as automations_router + from .agent_action_log_route import router as agent_action_log_router from .agent_flags_route import router as agent_flags_router from .agent_permissions_route import router as agent_permissions_router @@ -53,7 +55,6 @@ from .search_source_connectors_routes import router as search_source_connectors_ from .search_spaces_routes import router as search_spaces_router from .slack_add_connector_route import router as slack_add_connector_router from .stripe_routes import router as stripe_router -from .surfsense_docs_routes import router as surfsense_docs_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router @@ -106,7 +107,6 @@ router.include_router(new_llm_config_router) # LLM configs with prompt configur router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks -router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync router.include_router( mcp_oauth_router @@ -119,3 +119,4 @@ router.include_router(youtube_router) # 
YouTube playlist resolution router.include_router(prompts_router) router.include_router(memory_router) # User personal memory (memory.md style) router.include_router(team_memory_router) # Search-space team memory +router.include_router(automations_router) # Automations CRUD + run history diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index f9d694e5a..eb952e684 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -351,10 +351,9 @@ async def stream_anonymous_chat( async def _generate(): from langchain_core.messages import AIMessage, HumanMessage - from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent + from app.agents.new_chat.anonymous_agent import create_anonymous_chat_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.db import shielded_async_session - from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.services.token_tracking_service import start_turn from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events @@ -363,24 +362,23 @@ async def stream_anonymous_chat( streaming_service = VercelStreamingService() try: - async with shielded_async_session() as session: - connector_service = ConnectorService(session, search_space_id=None) + async with shielded_async_session(): checkpointer = await get_checkpointer() anon_thread_id = f"anon-{session_id}-{request_id}" - agent = await create_surfsense_deep_agent( + # Load the optional uploaded document as read-only context. + anon_doc = await _load_anon_document(session_id) + + # Minimal Q/A agent: web_search only (when enabled), no + # filesystem / persistence / subagents. The uploaded document + # is injected into the system prompt as read-only context. 
+ agent = await create_anonymous_chat_agent( llm=llm, - search_space_id=0, - db_session=session, - connector_service=connector_service, checkpointer=checkpointer, - user_id=None, - thread_id=None, - agent_config=agent_config, - enabled_tools=list(enabled_for_agent), - disabled_tools=None, anon_session_id=session_id, + anon_doc=anon_doc, + enable_web_search="web_search" in enabled_for_agent, ) langchain_messages = [] @@ -396,7 +394,6 @@ async def stream_anonymous_chat( input_state = { "messages": langchain_messages, - "search_space_id": 0, } langgraph_config = { @@ -500,6 +497,38 @@ ANON_ALLOWED_EXTENSIONS = PLAINTEXT_EXTENSIONS | DIRECT_CONVERT_EXTENSIONS ANON_DOC_REDIS_PREFIX = "anon:doc:" +async def _load_anon_document(session_id: str) -> dict[str, Any] | None: + """Read the anonymous session's uploaded document from Redis. + + Returns ``{"title", "content"}`` for read-only injection into the agent's + system prompt, or ``None`` when nothing was uploaded for this session. + """ + import json as _json + + import redis.asyncio as aioredis + + redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True) + redis_key = f"{ANON_DOC_REDIS_PREFIX}{session_id}" + try: + data = await redis_client.get(redis_key) + if not data: + return None + payload = _json.loads(data) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to load anonymous document from Redis: %s", exc) + return None + finally: + await redis_client.aclose() + + content = str(payload.get("content") or "") + if not content: + return None + return { + "title": str(payload.get("filename") or "uploaded_document"), + "content": content, + } + + class AnonDocResponse(BaseModel): filename: str size_bytes: int diff --git a/surfsense_backend/app/routes/folders_routes.py b/surfsense_backend/app/routes/folders_routes.py index 2dc9bceac..dca55f31e 100644 --- a/surfsense_backend/app/routes/folders_routes.py +++ b/surfsense_backend/app/routes/folders_routes.py @@ -525,11 
+525,8 @@ async def bulk_move_documents( detail="Cannot move documents to a folder in a different search space", ) - await session.execute( - Document.__table__.update() - .where(Document.id.in_(request.document_ids)) - .values(folder_id=request.folder_id) - ) + for doc in documents: + doc.folder_id = request.folder_id await session.commit() return {"message": f"{len(request.document_ids)} documents moved successfully"} diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 44fc1c392..63b7732a9 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1771,6 +1771,11 @@ async def handle_new_chat( if request.mentioned_documents else None ) + mentioned_connectors_payload = ( + [doc.model_dump() for doc in request.mentioned_connectors] + if request.mentioned_connectors + else None + ) return StreamingResponse( stream_new_chat( @@ -1780,8 +1785,9 @@ async def handle_new_chat( user_id=str(user.id), llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, - mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, mentioned_folder_ids=request.mentioned_folder_ids, + mentioned_connector_ids=request.mentioned_connector_ids, + mentioned_connectors=mentioned_connectors_payload, mentioned_documents=mentioned_documents_payload, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, @@ -2258,6 +2264,11 @@ async def regenerate_response( if request.mentioned_documents else None ) + mentioned_connectors_payload = ( + [doc.model_dump() for doc in request.mentioned_connectors] + if request.mentioned_connectors + else None + ) try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), @@ -2266,8 +2277,9 @@ async def regenerate_response( user_id=str(user.id), llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, - 
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, mentioned_folder_ids=request.mentioned_folder_ids, + mentioned_connector_ids=request.mentioned_connector_ids, + mentioned_connectors=mentioned_connectors_payload, mentioned_documents=mentioned_documents_payload, checkpoint_id=target_checkpoint_id, needs_history_bootstrap=thread.needs_history_bootstrap, diff --git a/surfsense_backend/app/routes/rbac_routes.py b/surfsense_backend/app/routes/rbac_routes.py index 38ae31269..3b91e456d 100644 --- a/surfsense_backend/app/routes/rbac_routes.py +++ b/surfsense_backend/app/routes/rbac_routes.py @@ -107,6 +107,12 @@ PERMISSION_DESCRIPTIONS = { "settings:view": "View search space settings", "settings:update": "Modify search space settings", "settings:delete": "Delete the entire search space", + # Automations + "automations:create": "Create automations from chat or JSON", + "automations:read": "View automations, their triggers, and run history", + "automations:update": "Edit automations and manage their triggers", + "automations:delete": "Remove automations from the search space", + "automations:execute": "Manually fire automations", # Full access "*": "Full access to all features and settings", } diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 1338fe16b..3060fdf4a 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -43,6 +43,7 @@ from app.db import ( async_session_maker, get_async_session, ) +from app.observability import metrics as ot_metrics, otel as ot from app.schemas import ( GoogleDriveIndexRequest, MCPConnectorCreate, @@ -104,7 +105,9 @@ async def _run_indexing_heartbeat_loop(notification_id: int) -> None: await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL) try: get_heartbeat_redis_client().setex(key, HEARTBEAT_TTL_SECONDS, "alive") + 
ot_metrics.record_celery_heartbeat_refresh(heartbeat_type="connector") except Exception as e: + ot_metrics.record_celery_heartbeat_failure(heartbeat_type="connector") logger.warning( f"Failed to refresh Redis heartbeat for notification " f"{notification_id}: {e}" @@ -1243,6 +1246,12 @@ async def _persist_auth_expired(session: AsyncSession, connector_id: int) -> Non """Flag a connector as auth_expired so the frontend shows a re-auth prompt.""" from sqlalchemy.orm.attributes import flag_modified + ot.add_event( + "connector.auth.expired", + { + "error.category": "auth_failed", + }, + ) try: result = await session.execute( select(SearchSourceConnector).where( @@ -1302,6 +1311,13 @@ async def _run_indexing_with_notifications( try: connector_lock_acquired = acquire_connector_indexing_lock(connector_id) if not connector_lock_acquired: + ot.add_event( + "connector.sync.skipped", + { + "skip.reason": "lock_contention", + "error.category": "lock_contention", + }, + ) logger.info( f"Skipping indexing for connector {connector_id} " "(another worker already holds Redis connector lock)" @@ -1338,7 +1354,13 @@ async def _run_indexing_with_notifications( get_heartbeat_redis_client().setex( heartbeat_key, HEARTBEAT_TTL_SECONDS, "0" ) + ot_metrics.record_celery_heartbeat_refresh( + heartbeat_type="connector" + ) except Exception as e: + ot_metrics.record_celery_heartbeat_failure( + heartbeat_type="connector" + ) logger.warning(f"Failed to set initial Redis heartbeat: {e}") # Start a background coroutine that refreshes the @@ -1366,6 +1388,15 @@ async def _run_indexing_with_notifications( ) -> None: """Callback to update notification during API retries (rate limits, etc.)""" nonlocal notification + ot.add_event( + "connector.retry.scheduled", + { + "retry.reason": retry_reason, + "retry.attempt": attempt, + "retry.max": max_attempts, + "retry.delay_ms": int(wait_seconds * 1000), + }, + ) if notification: try: await session.refresh(notification) @@ -1397,8 +1428,14 @@ async def 
_run_indexing_with_notifications( get_heartbeat_redis_client().setex( heartbeat_key, HEARTBEAT_TTL_SECONDS, str(indexed_count) ) + ot_metrics.record_celery_heartbeat_refresh( + heartbeat_type="connector" + ) except Exception as e: # Don't let Redis errors break the indexing + ot_metrics.record_celery_heartbeat_failure( + heartbeat_type="connector" + ) logger.warning(f"Failed to set Redis heartbeat: {e}") try: diff --git a/surfsense_backend/app/routes/surfsense_docs_routes.py b/surfsense_backend/app/routes/surfsense_docs_routes.py deleted file mode 100644 index 0d5428dec..000000000 --- a/surfsense_backend/app/routes/surfsense_docs_routes.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Routes for Surfsense documentation. - -These endpoints support the citation system for Surfsense docs, -allowing the frontend to fetch document details when a user clicks -on a [citation:doc-XXX] link. -""" - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from app.db import ( - SurfsenseDocsChunk, - SurfsenseDocsDocument, - User, - get_async_session, -) -from app.schemas import PaginatedResponse -from app.schemas.surfsense_docs import ( - SurfsenseDocsChunkRead, - SurfsenseDocsDocumentRead, - SurfsenseDocsDocumentWithChunksRead, -) -from app.users import current_active_user -from app.utils.surfsense_docs import surfsense_docs_public_url - -router = APIRouter() - - -@router.get( - "/surfsense-docs/by-chunk/{chunk_id}", - response_model=SurfsenseDocsDocumentWithChunksRead, -) -async def get_surfsense_doc_by_chunk_id( - chunk_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Retrieves a Surfsense documentation document based on a chunk ID. - - This endpoint is used by the frontend to resolve [citation:doc-XXX] links. 
- """ - try: - # Get the chunk - chunk_result = await session.execute( - select(SurfsenseDocsChunk).filter(SurfsenseDocsChunk.id == chunk_id) - ) - chunk = chunk_result.scalars().first() - - if not chunk: - raise HTTPException( - status_code=404, - detail=f"Surfsense docs chunk with id {chunk_id} not found", - ) - - # Get the associated document with all its chunks - document_result = await session.execute( - select(SurfsenseDocsDocument) - .options(selectinload(SurfsenseDocsDocument.chunks)) - .filter(SurfsenseDocsDocument.id == chunk.document_id) - ) - document = document_result.scalars().first() - - if not document: - raise HTTPException( - status_code=404, - detail="Surfsense docs document not found", - ) - - # Sort chunks by ID - sorted_chunks = sorted(document.chunks, key=lambda x: x.id) - - return SurfsenseDocsDocumentWithChunksRead( - id=document.id, - title=document.title, - source=document.source, - public_url=surfsense_docs_public_url(document.source), - content=document.content, - chunks=[ - SurfsenseDocsChunkRead(id=c.id, content=c.content) - for c in sorted_chunks - ], - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Failed to retrieve Surfsense documentation: {e!s}", - ) from e - - -@router.get( - "/surfsense-docs", - response_model=PaginatedResponse[SurfsenseDocsDocumentRead], -) -async def list_surfsense_docs( - page: int = 0, - page_size: int = 50, - title: str | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - List all Surfsense documentation documents. - - Args: - page: Zero-based page index. - page_size: Number of items per page (default: 50). - title: Optional title filter (case-insensitive substring match). - session: Database session (injected). - user: Current authenticated user (injected). - - Returns: - PaginatedResponse[SurfsenseDocsDocumentRead]: Paginated list of Surfsense docs. 
- """ - try: - # Base query - query = select(SurfsenseDocsDocument) - count_query = select(func.count()).select_from(SurfsenseDocsDocument) - - # Filter by title if provided - if title and title.strip(): - query = query.filter(SurfsenseDocsDocument.title.ilike(f"%{title}%")) - count_query = count_query.filter( - SurfsenseDocsDocument.title.ilike(f"%{title}%") - ) - - # Get total count - total_result = await session.execute(count_query) - total = total_result.scalar() or 0 - - # Calculate offset - offset = page * page_size - - # Get paginated results - result = await session.execute( - query.order_by(SurfsenseDocsDocument.title).offset(offset).limit(page_size) - ) - docs = result.scalars().all() - - # Convert to response format - items = [ - SurfsenseDocsDocumentRead( - id=doc.id, - title=doc.title, - source=doc.source, - public_url=surfsense_docs_public_url(doc.source), - content=doc.content, - created_at=doc.created_at, - updated_at=doc.updated_at, - ) - for doc in docs - ] - - has_more = (offset + len(items)) < total - - return PaginatedResponse( - items=items, - total=total, - page=page, - page_size=page_size, - has_more=has_more, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Failed to list Surfsense documentation: {e!s}", - ) from e diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c5315cce5..ab95f9b6b 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -203,13 +203,11 @@ class NewChatUserImagePart(BaseModel): class MentionedDocumentInfo(BaseModel): """Display metadata for a single ``@``-mention chip. - Carries either a knowledge-base document or a knowledge-base folder - (discriminated by ``kind``). The full triple - ``{id, title, document_type}`` is forwarded by the frontend mention - chip so the server can embed it in the persisted user message - ``ContentPart[]`` (single ``mentioned-documents`` part). 
The - history loader then renders the chips on reload without an extra - fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape. + Carries a knowledge-base document, knowledge-base folder, or + connected account (discriminated by ``kind``). Each kind uses its + real identity fields: docs carry ``document_type``, folders carry + only their folder id/title, and connectors carry ``connector_type`` + plus account metadata. ``kind`` defaults to ``"doc"`` so legacy clients and persisted rows that predate folder mentions deserialise unchanged. @@ -217,18 +215,18 @@ class MentionedDocumentInfo(BaseModel): id: int title: str = Field(..., min_length=1, max_length=500) - document_type: str = Field(..., min_length=1, max_length=100) - kind: Literal["doc", "folder"] = Field( + document_type: str | None = Field(default=None, min_length=1, max_length=100) + kind: Literal["doc", "folder", "connector"] = Field( default="doc", description=( "Discriminator for the chip's referent: ``doc`` is a " "knowledge-base ``Document`` row, ``folder`` is a " - "knowledge-base ``Folder`` row. Folders carry the sentinel " - "``document_type='FOLDER'`` to keep the frontend dedup key " - "``(kind:document_type:id)`` from colliding doc and folder " - "ids that happen to share an integer value." + "knowledge-base ``Folder`` row, and ``connector`` is a " + "concrete connected account." 
), ) + connector_type: str | None = Field(default=None, max_length=100) + account_name: str | None = Field(default=None, max_length=255) class NewChatRequest(BaseModel): @@ -241,9 +239,6 @@ class NewChatRequest(BaseModel): mentioned_document_ids: list[int] | None = ( None # Optional document IDs mentioned with @ in the chat ) - mentioned_surfsense_doc_ids: list[int] | None = ( - None # Optional SurfSense documentation IDs mentioned with @ in the chat - ) mentioned_folder_ids: list[int] | None = Field( default=None, description=( @@ -266,6 +261,18 @@ class NewChatRequest(BaseModel): "a mentioned-documents part." ), ) + mentioned_connector_ids: list[int] | None = Field( + default=None, + description="Optional concrete connector account IDs the user @-mentioned.", + ) + mentioned_connectors: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display/context metadata for selected connector accounts. " + "Kept separate from document/folder id arrays so tools can " + "prefer the exact account the user selected." + ), + ) disabled_tools: list[str] | None = ( None # Optional list of tool names the user has disabled from the UI ) @@ -316,7 +323,6 @@ class RegenerateRequest(BaseModel): None # New user query (for edit). None = reload with same query ) mentioned_document_ids: list[int] | None = None - mentioned_surfsense_doc_ids: list[int] | None = None mentioned_folder_ids: list[int] | None = Field( default=None, description=( @@ -335,6 +341,8 @@ class RegenerateRequest(BaseModel): "new user message. None means no chip metadata." 
), ) + mentioned_connector_ids: list[int] | None = None + mentioned_connectors: list[MentionedDocumentInfo] | None = None disabled_tools: list[str] | None = None filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" diff --git a/surfsense_backend/app/schemas/surfsense_docs.py b/surfsense_backend/app/schemas/surfsense_docs.py deleted file mode 100644 index 3adf25032..000000000 --- a/surfsense_backend/app/schemas/surfsense_docs.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Schemas for Surfsense documentation. -""" - -from datetime import datetime - -from pydantic import BaseModel, ConfigDict - - -class SurfsenseDocsChunkRead(BaseModel): - """Schema for a Surfsense docs chunk.""" - - id: int - content: str - - model_config = ConfigDict(from_attributes=True) - - -class SurfsenseDocsDocumentRead(BaseModel): - """Schema for a Surfsense docs document (without chunks).""" - - id: int - title: str - source: str - public_url: str - content: str - created_at: datetime | None = None - updated_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - -class SurfsenseDocsDocumentWithChunksRead(BaseModel): - """Schema for a Surfsense docs document with its chunks.""" - - id: int - title: str - source: str - public_url: str - content: str - chunks: list[SurfsenseDocsChunkRead] - - model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index d73a0d4ce..920f51d84 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -835,7 +835,14 @@ class ComposioService: ) if not result.get("success"): - return [], None, result.get("error", "Unknown error") + # 4-tuple to match this function's declared return shape + # ``(messages, next_page_token, result_size_estimate, error)``. 
+ # The error branch previously dropped the + # ``result_size_estimate`` slot, which crashed the caller's + # unpack with ``ValueError: not enough values to unpack + # (expected 4, got 3)`` and hid the real Composio error + # (e.g. expired connected account / invalid API key). + return [], None, None, result.get("error", "Unknown error") data = result.get("data", {}) diff --git a/surfsense_backend/app/services/gmail/kb_sync_service.py b/surfsense_backend/app/services/gmail/kb_sync_service.py index 6ff5f3c2b..85e25fcb6 100644 --- a/surfsense_backend/app/services/gmail/kb_sync_service.py +++ b/surfsense_backend/app/services/gmail/kb_sync_service.py @@ -101,9 +101,7 @@ class GmailKBSyncService: else: logger.warning("No LLM configured -- using fallback summary") summary_content = f"Gmail Message: {subject}\n\n{indexable_content}" - summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(indexable_content) now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 1f017ec4d..e59868aff 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -116,9 +116,7 @@ class GoogleCalendarKBSyncService: summary_content = ( f"Google Calendar Event: {event_summary}\n\n{indexable_content}" ) - summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(indexable_content) now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -297,9 +295,7 @@ class GoogleCalendarKBSyncService: summary_content = ( f"Google Calendar Event: {event_summary}\n\n{indexable_content}" ) - 
summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(indexable_content) now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/surfsense_backend/app/services/jira/kb_sync_service.py b/surfsense_backend/app/services/jira/kb_sync_service.py index 5f6668377..37001a476 100644 --- a/surfsense_backend/app/services/jira/kb_sync_service.py +++ b/surfsense_backend/app/services/jira/kb_sync_service.py @@ -98,9 +98,7 @@ class JiraKBSyncService: summary_content = ( f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}" ) - summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(issue_content) now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -214,9 +212,7 @@ class JiraKBSyncService: summary_content = ( f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}" ) - summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(issue_content) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index fa97fb33a..aadb60cde 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -682,11 +682,7 @@ def get_planner_llm() -> ChatLiteLLM | None: from app.agents.new_chat.llm_config import create_chat_litellm_from_config planner_cfg = next( - ( - cfg - for cfg in config.GLOBAL_LLM_CONFIGS - if cfg.get("is_planner") is True - ), + (cfg for cfg in config.GLOBAL_LLM_CONFIGS if cfg.get("is_planner") is True), None, ) if not planner_cfg: diff --git a/surfsense_backend/app/services/onedrive/kb_sync_service.py 
b/surfsense_backend/app/services/onedrive/kb_sync_service.py index e1da3b4a1..731f081dd 100644 --- a/surfsense_backend/app/services/onedrive/kb_sync_service.py +++ b/surfsense_backend/app/services/onedrive/kb_sync_service.py @@ -96,9 +96,7 @@ class OneDriveKBSyncService: else: logger.warning("No LLM configured — using fallback summary") summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}" - summary_embedding = await asyncio.to_thread( - embed_text, summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, summary_content) chunks = await create_document_chunks(indexable_content) now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/surfsense_backend/app/session_events.py b/surfsense_backend/app/session_events.py new file mode 100644 index 000000000..048df2b46 --- /dev/null +++ b/surfsense_backend/app/session_events.py @@ -0,0 +1,103 @@ +"""SQLAlchemy session event hooks — wired once at app startup. + +Detects document folder arrivals across every ORM commit and publishes +``document.entered_folder`` events to the bus after the transaction is durable. 
+""" + +from __future__ import annotations + +import asyncio +import logging + +from sqlalchemy import event +from sqlalchemy.orm import Session, attributes + +from app.db import Document, DocumentStatus +from app.event_bus.bus import EventBus, bus as default_bus +from app.event_bus.events.document_entered_folder import payload_if_entered_folder + +logger = logging.getLogger(__name__) + +_PENDING_KEY = "_entered_folder_pending" + + +def _after_flush(session: Session, flush_context: object) -> None: + """Collect folder-arrival candidates while attribute history is still available.""" + pending: list[dict] = [] + + for obj in list(session.new) + list(session.dirty): + if not isinstance(obj, Document): + continue + + history = attributes.get_history(obj, "folder_id") + if not history.added: + continue + + new_folder_id = history.added[0] + previous_folder_id = history.deleted[0] if history.deleted else None + + result = payload_if_entered_folder( + document_id=obj.id, + search_space_id=obj.search_space_id, + new_folder_id=new_folder_id, + previous_folder_id=previous_folder_id, + folder_id_changed=True, + status_state=DocumentStatus.get_state(obj.status) or "", + document_type=obj.document_type.value if obj.document_type else "", + title=obj.title or "", + connector_id=obj.connector_id, + created_by_id=str(obj.created_by_id) if obj.created_by_id else None, + ) + if result is not None: + pending.append(result) + + setattr(session, _PENDING_KEY, pending) + + +def _after_commit(session: Session) -> None: + """Publish collected events now that the transaction is durable.""" + pending: list[dict] = getattr(session, _PENDING_KEY, []) + if not pending: + return + setattr(session, _PENDING_KEY, []) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + logger.warning("No running event loop — skipping %d event(s)", len(pending)) + return + + tasks = [ + loop.create_task( + default_bus.publish( + item["event_type"], + item["payload"], + 
search_space_id=item["search_space_id"], + ) + ) + for item in pending + ] + for task in tasks: + task.add_done_callback( + lambda t: ( + logger.error("event publish failed: %s", t.exception()) + if not t.cancelled() and t.exception() + else None + ) + ) + + +def _after_rollback(session: Session) -> None: + """Discard any pending events — the transaction did not commit.""" + setattr(session, _PENDING_KEY, []) + + +def register_session_hooks(bus: EventBus = default_bus) -> None: + """Register document folder-arrival hooks on the SQLAlchemy Session class. + + Call once at application startup (e.g. in ``app.app`` lifespan). Idempotent + — SQLAlchemy deduplicates identical listener registrations. + """ + event.listen(Session, "after_flush", _after_flush) + event.listen(Session, "after_commit", _after_commit) + event.listen(Session, "after_rollback", _after_rollback) diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index b23359f36..6ea7a2e68 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -37,6 +37,10 @@ def get_celery_session_maker() -> async_sessionmaker: poolclass=NullPool, echo=False, ) + with contextlib.suppress(Exception): + from app.observability.bootstrap import instrument_sqlalchemy_engine + + instrument_sqlalchemy_engine(_celery_engine) _celery_session_maker = async_sessionmaker( _celery_engine, expire_on_commit=False ) diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index 08d96cfa0..50f757473 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -1,14 +1,52 @@ """Celery tasks for connector indexing.""" import logging +import time import traceback +from collections.abc import Awaitable, Callable + +from celery import current_task from 
app.celery_app import celery_app -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task +from app.observability import metrics as ot_metrics, otel as ot +from app.tasks.celery_tasks import ( + get_celery_session_maker, + run_async_celery_task as _run_async_celery_task, +) logger = logging.getLogger(__name__) +def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T: + """Run connector sync work and record aggregate connector metrics.""" + task_name = getattr(current_task, "name", None) or "unknown" + t0 = time.perf_counter() + status = "failed" + error_category: str | None = None + try: + with ot.connector_sync_span(connector_type=task_name) as sp: + try: + result = _run_async_celery_task(coro_factory) + sp.set_attribute("connector.status", "success") + except Exception as exc: + error_category = ot_metrics.categorize_exception(exc) + sp.set_attribute("connector.error.category", error_category) + raise + status = "success" + return result + finally: + elapsed_s = time.perf_counter() - t0 + ot_metrics.record_connector_sync_duration( + elapsed_s, + connector_type=task_name, + ) + ot_metrics.record_connector_sync_outcome( + connector_type=task_name, + status=status, + error_category=error_category, + ) + + def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> None: """ Handle greenlet_spawn errors with detailed logging for debugging. 
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index c78e376bd..1f9609968 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -9,6 +9,7 @@ from uuid import UUID from app.celery_app import celery_app from app.config import config +from app.observability import metrics as ot_metrics from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task @@ -59,7 +60,9 @@ def _start_heartbeat(notification_id: int) -> None: try: key = _get_heartbeat_key(notification_id) _get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "started") + ot_metrics.record_celery_heartbeat_refresh(heartbeat_type="document") except Exception as e: + ot_metrics.record_celery_heartbeat_failure(heartbeat_type="document") logger.warning( f"Failed to set initial heartbeat for notification {notification_id}: {e}" ) @@ -87,7 +90,9 @@ async def _run_heartbeat_loop(notification_id: int): await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL) try: _get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "alive") + ot_metrics.record_celery_heartbeat_refresh(heartbeat_type="document") except Exception as e: + ot_metrics.record_celery_heartbeat_failure(heartbeat_type="document") logger.warning( f"Failed to refresh heartbeat for notification {notification_id}: {e}" ) diff --git a/surfsense_backend/app/tasks/chat/persistence.py b/surfsense_backend/app/tasks/chat/persistence.py index 37be50705..9d100c13c 100644 --- a/surfsense_backend/app/tasks/chat/persistence.py +++ b/surfsense_backend/app/tasks/chat/persistence.py @@ -109,7 +109,7 @@ def _build_user_content( [{"type": "text", "text": "..."}, {"type": "image", "image": "data:..."}, {"type": "mentioned-documents", "documents": [{"id": int, 
- "title": str, "document_type": str, "kind": "doc" | "folder"}, + "title": str, "kind": "doc" | "folder" | "connector", ...}, ...]}] The companion reader is @@ -117,8 +117,8 @@ def _build_user_content( which expects exactly this shape — keep them in sync. ``mentioned_documents``: optional list of mention chip dicts. Each - dict may include a ``kind`` discriminator (``"doc"`` or ``"folder"``) - so the persisted ContentPart round-trips folder chips on reload. + dict may include a ``kind`` discriminator so the persisted + ContentPart round-trips folder and connector chips on reload. When ``kind`` is missing we default to ``"doc"`` so legacy clients that haven't migrated to the union schema still persist correctly. """ @@ -134,18 +134,27 @@ def _build_user_content( doc_id = doc.get("id") title = doc.get("title") document_type = doc.get("document_type") - if doc_id is None or title is None or document_type is None: - continue kind_raw = doc.get("kind", "doc") - kind = kind_raw if kind_raw in ("doc", "folder") else "doc" - normalized.append( - { - "id": doc_id, - "title": str(title), - "document_type": str(document_type), - "kind": kind, - } - ) + kind = kind_raw if kind_raw in ("doc", "folder", "connector") else "doc" + if doc_id is None or title is None: + continue + if kind == "doc" and document_type is None: + continue + item = { + "id": doc_id, + "title": str(title), + "kind": kind, + } + if document_type is not None: + item["document_type"] = str(document_type) + if kind == "connector": + connector_type = doc.get("connector_type") + if connector_type is None: + continue + account_name = doc.get("account_name") or title + item["connector_type"] = str(connector_type) + item["account_name"] = str(account_name) + normalized.append(item) if normalized: parts.append({"type": "mentioned-documents", "documents": normalized}) return parts diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 
c9faa1691..e150cf494 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -14,6 +14,7 @@ import contextlib import gc import json import logging +import sys import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field @@ -24,7 +25,6 @@ from uuid import UUID import anyio from langchain_core.messages import HumanMessage from sqlalchemy.future import select -from sqlalchemy.orm import selectinload from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent @@ -54,10 +54,10 @@ from app.db import ( NewChatThread, Report, SearchSourceConnectorType, - SurfsenseDocsDocument, async_session_maker, shielded_async_session, ) +from app.observability import metrics as ot_metrics, otel as ot from app.prompts import TITLE_GENERATION_PROMPT from app.services.auto_model_pin_service import ( mark_runtime_cooldown, @@ -75,7 +75,6 @@ from app.tasks.chat.streaming.helpers.interrupt_inspector import ( ) from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap -from app.utils.surfsense_docs import surfsense_docs_public_url from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() @@ -196,58 +195,6 @@ def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: return out -def format_mentioned_surfsense_docs_as_context( - documents: list[SurfsenseDocsDocument], -) -> str: - """Format mentioned SurfSense documentation as context for the agent.""" - if not documents: - return "" - - context_parts = [""] - context_parts.append( - "The user has explicitly mentioned the following SurfSense documentation pages. " - "These are official documentation about how to use SurfSense and should be used to answer questions about the application. 
" - "Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123])." - ) - - for doc in documents: - public_url = surfsense_docs_public_url(doc.source) - metadata_json = json.dumps( - {"source": doc.source, "public_url": public_url}, ensure_ascii=False - ) - - context_parts.append("") - context_parts.append("") - context_parts.append(f" doc-{doc.id}") - context_parts.append(" SURFSENSE_DOCS") - context_parts.append(f" <![CDATA[{doc.title}]]>") - context_parts.append(f" ") - context_parts.append( - f" " - ) - context_parts.append("") - context_parts.append("") - context_parts.append("") - - if hasattr(doc, "chunks") and doc.chunks: - for chunk in doc.chunks: - context_parts.append( - f" " - ) - else: - context_parts.append( - f" " - ) - - context_parts.append("") - context_parts.append("") - context_parts.append("") - - context_parts.append("") - - return "\n".join(context_parts) - - def extract_todos_from_deepagents(command_output) -> dict: """ Extract todos from deepagents' TodoListMiddleware Command output. 
@@ -835,8 +782,9 @@ async def stream_new_chat( user_id: str | None = None, llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, - mentioned_surfsense_doc_ids: list[int] | None = None, mentioned_folder_ids: list[int] | None = None, + mentioned_connector_ids: list[int] | None = None, + mentioned_connectors: list[dict[str, Any]] | None = None, mentioned_documents: list[dict[str, Any]] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, @@ -865,7 +813,6 @@ async def stream_new_chat( llm_config_id: The LLM configuration ID (default: -1 for first global config) needs_history_bootstrap: If True, load message history from DB (for cloned chats) mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat - mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode) checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) @@ -883,6 +830,20 @@ async def stream_new_chat( stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" stream_result.filesystem_mode = fs_mode stream_result.client_platform = fs_platform + chat_agent_mode = "unknown" + chat_outcome = "success" + chat_error_category: str | None = None + chat_span_cm = ot.chat_request_span( + chat_id=chat_id, + search_space_id=search_space_id, + flow=flow, + request_id=request_id, + turn_id=stream_result.turn_id, + filesystem_mode=fs_mode, + client_platform=fs_platform, + agent_mode=chat_agent_mode, + ) + chat_span = chat_span_cm.__enter__() _log_file_contract("turn_start", stream_result) _perf_log.info( "[stream_new_chat] filesystem_mode=%s client_platform=%s", @@ -971,6 +932,14 @@ async def stream_new_chat( requires_image_input=_requires_image_input, ) ).resolved_llm_config_id + ot.add_event( + "model.pin.resolved", + { + "pin.requested_id": requested_llm_config_id, + 
"pin.resolved_id": llm_config_id, + "pin.requires_image_input": _requires_image_input, + }, + ) except ValueError as pin_error: # Auto-pin's "no vision-capable cfg" path raises a ValueError # whose message we map to the friendly image-input SSE error @@ -987,6 +956,13 @@ async def stream_new_chat( if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" else "server_error" ) + if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT": + ot.add_event( + "quota.denied", + { + "quota.code": error_code, + }, + ) yield _emit_stream_error( message=str(pin_error), error_kind=error_kind, @@ -1041,6 +1017,12 @@ async def stream_new_chat( model_label = ( agent_config.config_name or agent_config.model_name or "model" ) + ot.add_event( + "quota.denied", + { + "quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", + }, + ) yield _emit_stream_error( message=( f"The selected model ({model_label}) does not support " @@ -1084,6 +1066,12 @@ async def stream_new_chat( ) _premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: + ot.add_event( + "quota.denied", + { + "quota.code": "PREMIUM_QUOTA_EXHAUSTED", + }, + ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -1097,6 +1085,13 @@ async def stream_new_chat( requires_image_input=_requires_image_input, ) ).resolved_llm_config_id + ot.add_event( + "model.repin", + { + "repin.reason": "premium_quota_exhausted", + "repin.to_config_id": llm_config_id, + }, + ) except ValueError as pin_error: yield _emit_stream_error( message=str(pin_error), @@ -1189,6 +1184,9 @@ async def stream_new_chat( from app.config import config as _app_config use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) + chat_agent_mode = "multi" if use_multi_agent else "single" + with contextlib.suppress(Exception): + chat_span.set_attribute("agent.mode", chat_agent_mode) _t0 = time.perf_counter() agent_factory = ( @@ -1240,19 +1238,7 @@ async def stream_new_chat( # Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware # 
which merges them into the scoped filesystem with full document - # structure. Only SurfSense docs and report context are inlined here. - - # Fetch mentioned SurfSense docs if any - mentioned_surfsense_docs: list[SurfsenseDocsDocument] = [] - if mentioned_surfsense_doc_ids: - result = await session.execute( - select(SurfsenseDocsDocument) - .options(selectinload(SurfsenseDocsDocument.chunks)) - .filter( - SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids), - ) - ) - mentioned_surfsense_docs = list(result.scalars().all()) + # structure. Only report context is inlined here. # Fetch the most recent report(s) in this thread so the LLM can # easily find report_id for versioning decisions, instead of @@ -1286,10 +1272,7 @@ async def stream_new_chat( agent_user_query = user_query accepted_folder_ids: list[int] = [] if fs_mode == FilesystemMode.CLOUD.value and ( - mentioned_document_ids - or mentioned_surfsense_doc_ids - or mentioned_folder_ids - or mentioned_documents + mentioned_document_ids or mentioned_folder_ids or mentioned_documents ): from app.schemas.new_chat import ( MentionedDocumentInfo as _MentionedDocumentInfo, @@ -1315,22 +1298,43 @@ async def stream_new_chat( search_space_id=search_space_id, mentioned_documents=chip_objs, mentioned_document_ids=mentioned_document_ids, - mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids, mentioned_folder_ids=mentioned_folder_ids, ) agent_user_query = substitute_in_text(user_query, resolved.token_to_path) accepted_folder_ids = resolved.mentioned_folder_ids - # Format the user query with context (SurfSense docs + reports only). + # Format the user query with context (reports only). # Uses ``agent_user_query`` so the LLM sees backtick-wrapped paths # instead of bare ``@title`` tokens. 
final_query = agent_user_query context_parts = [] - if mentioned_surfsense_docs: - context_parts.append( - format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs) - ) + if mentioned_connectors: + connector_lines = [] + for connector in mentioned_connectors: + if not isinstance(connector, dict): + continue + connector_id = connector.get("id") + connector_type = connector.get("connector_type") or connector.get( + "document_type" + ) + account_name = connector.get("account_name") or connector.get("title") + if connector_id is None or connector_type is None: + continue + connector_lines.append( + f' - connector_id={connector_id}, connector_type="{connector_type}", ' + f'account_name="{account_name or ""}"' + ) + if connector_lines: + context_parts.append( + "\n" + "The user selected these exact connector accounts with @. " + "These entries are selection metadata, not retrieved connector content. " + "When a connector-backed tool needs an account, use the matching " + "connector_id from this list if the tool supports connector_id:\n" + + "\n".join(connector_lines) + + "\n" + ) # Surface report IDs prominently so the LLM doesn't have to # retrieve them from old tool responses in conversation history. @@ -1535,12 +1539,8 @@ async def stream_new_chat( stream_result.content_builder = AssistantContentBuilder() # Initial thinking step - analyzing the request - if mentioned_surfsense_docs: - initial_title = "Analyzing referenced content" - action_verb = "Analyzing" - else: - initial_title = "Understanding your request" - action_verb = "Processing" + initial_title = "Understanding your request" + action_verb = "Processing" processing_parts = [] if user_query.strip(): @@ -1551,18 +1551,6 @@ async def stream_new_chat( else: processing_parts.append("(message)") - if mentioned_surfsense_docs: - doc_names = [] - for doc in mentioned_surfsense_docs: - title = doc.title - if len(title) > 30: - title = title[:27] + "..." 
- doc_names.append(title) - if len(doc_names) == 1: - processing_parts.append(f"[{doc_names[0]}]") - else: - processing_parts.append(f"[{len(doc_names)} docs]") - initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] initial_step_id = "thinking-1" @@ -1582,10 +1570,10 @@ async def stream_new_chat( items=initial_items, ) - # These ORM objects (with eagerly-loaded chunks) can be very large. - # They're only needed to build context strings already copied into - # final_query / langchain_messages — release them before streaming. - del mentioned_surfsense_docs, recent_reports + # These ORM objects can be large. They're only needed to build context + # strings already copied into final_query / langchain_messages — + # release them before streaming. + del recent_reports del langchain_messages, final_query # Check if this is the first assistant response so we can generate @@ -1725,6 +1713,8 @@ async def stream_new_chat( mentioned_folder_ids=list( accepted_folder_ids or mentioned_folder_ids or [] ), + mentioned_connector_ids=list(mentioned_connector_ids or []), + mentioned_connectors=list(mentioned_connectors or []), request_id=request_id, turn_id=stream_result.turn_id, ) @@ -1863,6 +1853,14 @@ async def stream_new_chat( llm_config_id, time.perf_counter() - _t0, ) + ot.add_event( + "chat.rate_limit.recovered", + { + "recovery.reason": "provider_rate_limited", + "recovery.previous_config_id": previous_config_id, + "recovery.fallback_config_id": llm_config_id, + }, + ) _log_chat_stream_error( flow=flow, error_kind="rate_limited", @@ -1893,6 +1891,12 @@ async def stream_new_chat( log_system_snapshot("stream_new_chat_END") if stream_result.is_interrupted: + ot.add_event( + "chat.interrupted", + { + "chat.flow": flow, + }, + ) if title_task is not None and not title_task.done(): title_task.cancel() @@ -2011,6 +2015,12 @@ async def stream_new_chat( user_message, error_extra, ) = _classify_stream_exception(e, flow_label="chat") + chat_outcome = error_code or error_kind 
or "error" + chat_error_category = ot_metrics.categorize_exception(e) + with contextlib.suppress(Exception): + chat_span.set_attribute("chat.outcome", chat_outcome) + chat_span.set_attribute("error.category", chat_error_category) + ot.record_error(chat_span, e) error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") @@ -2201,6 +2211,21 @@ async def stream_new_chat( ) trim_native_heap() log_system_snapshot("stream_new_chat_END") + with contextlib.suppress(Exception): + chat_span.set_attribute("chat.outcome", chat_outcome) + ot_metrics.record_chat_request_duration( + (time.perf_counter() - _t_total) * 1000, + flow=flow, + outcome=chat_outcome, + agent_mode=chat_agent_mode, + ) + ot_metrics.record_chat_request_outcome( + flow=flow, + outcome=chat_outcome, + agent_mode=chat_agent_mode, + error_category=chat_error_category, + ) + chat_span_cm.__exit__(*sys.exc_info()) async def stream_resume_chat( @@ -2225,6 +2250,20 @@ async def stream_resume_chat( stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" stream_result.filesystem_mode = fs_mode stream_result.client_platform = fs_platform + chat_agent_mode = "unknown" + chat_outcome = "success" + chat_error_category: str | None = None + chat_span_cm = ot.chat_request_span( + chat_id=chat_id, + search_space_id=search_space_id, + flow="resume", + request_id=request_id, + turn_id=stream_result.turn_id, + filesystem_mode=fs_mode, + client_platform=fs_platform, + agent_mode=chat_agent_mode, + ) + chat_span = chat_span_cm.__enter__() _log_file_contract("turn_start", stream_result) _perf_log.info( "[stream_resume] filesystem_mode=%s client_platform=%s", @@ -2297,6 +2336,14 @@ async def stream_resume_chat( selected_llm_config_id=llm_config_id, ) ).resolved_llm_config_id + ot.add_event( + "model.pin.resolved", + { + "pin.requested_id": requested_llm_config_id, + "pin.resolved_id": llm_config_id, + "pin.requires_image_input": 
False, + }, + ) except ValueError as pin_error: yield _emit_stream_error( message=str(pin_error), @@ -2353,6 +2400,12 @@ async def stream_resume_chat( ) _resume_premium_reserved_micros = reserve_amount_micros if not quota_result.allowed: + ot.add_event( + "quota.denied", + { + "quota.code": "PREMIUM_QUOTA_EXHAUSTED", + }, + ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -2365,6 +2418,13 @@ async def stream_resume_chat( force_repin_free=True, ) ).resolved_llm_config_id + ot.add_event( + "model.repin", + { + "repin.reason": "premium_quota_exhausted", + "repin.to_config_id": llm_config_id, + }, + ) except ValueError as pin_error: yield _emit_stream_error( message=str(pin_error), @@ -2454,6 +2514,9 @@ async def stream_resume_chat( visibility = thread_visibility or ChatVisibility.PRIVATE from app.config import config as _app_config + chat_agent_mode = "multi" if _app_config.MULTI_AGENT_CHAT_ENABLED else "single" + with contextlib.suppress(Exception): + chat_span.set_attribute("agent.mode", chat_agent_mode) _t0 = time.perf_counter() agent_factory = ( create_multi_agent_chat_deep_agent @@ -2695,6 +2758,14 @@ async def stream_resume_chat( llm_config_id, time.perf_counter() - _t0, ) + ot.add_event( + "chat.rate_limit.recovered", + { + "recovery.reason": "provider_rate_limited", + "recovery.previous_config_id": previous_config_id, + "recovery.fallback_config_id": llm_config_id, + }, + ) _log_chat_stream_error( flow="resume", error_kind="rate_limited", @@ -2722,6 +2793,12 @@ async def stream_resume_chat( chat_id, ) if stream_result.is_interrupted: + ot.add_event( + "chat.interrupted", + { + "chat.flow": "resume", + }, + ) usage_summary = accumulator.per_message_summary() _perf_log.info( "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", @@ -2815,6 +2892,12 @@ async def stream_resume_chat( user_message, error_extra, ) = _classify_stream_exception(e, flow_label="resume") + chat_outcome = error_code or error_kind or "error" + 
chat_error_category = ot_metrics.categorize_exception(e) + with contextlib.suppress(Exception): + chat_span.set_attribute("chat.outcome", chat_outcome) + chat_span.set_attribute("error.category", chat_error_category) + ot.record_error(chat_span, e) error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") @@ -2964,3 +3047,18 @@ async def stream_resume_chat( ) trim_native_heap() log_system_snapshot("stream_resume_chat_END") + with contextlib.suppress(Exception): + chat_span.set_attribute("chat.outcome", chat_outcome) + ot_metrics.record_chat_request_duration( + (time.perf_counter() - _t_total) * 1000, + flow="resume", + outcome=chat_outcome, + agent_mode=chat_agent_mode, + ) + ot_metrics.record_chat_request_outcome( + flow="resume", + outcome=chat_outcome, + agent_mode=chat_agent_mode, + error_category=chat_error_category, + ) + chat_span_cm.__exit__(*sys.exc_info()) diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/__init__.py b/surfsense_backend/app/tasks/chat/streaming/agent/__init__.py new file mode 100644 index 000000000..260dcb3f2 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/agent/__init__.py @@ -0,0 +1,8 @@ +"""Agent construction and per-turn event-loop drivers.""" + +from __future__ import annotations + +from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread +from app.tasks.chat.streaming.agent.event_loop import stream_agent_events + +__all__ = ["build_main_agent_for_thread", "stream_agent_events"] diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py new file mode 100644 index 000000000..0db42edbf --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py @@ -0,0 +1,49 @@ +"""Single per-thread agent (re)build path. 
async def build_main_agent_for_thread(
    agent_factory: Any,
    *,
    llm: Any,
    search_space_id: int,
    db_session: Any,
    connector_service: ConnectorService,
    checkpointer: Any,
    user_id: str | None,
    thread_id: int | None,
    agent_config: AgentConfig | None,
    firecrawl_api_key: str | None,
    thread_visibility: ChatVisibility | None,
    filesystem_selection: FilesystemSelection | None,
    disabled_tools: list[str] | None = None,
    mentioned_document_ids: list[int] | None = None,
) -> Any:
    """Build (or rebuild) the main agent for a thread via ``agent_factory``.

    Every argument is forwarded to the factory by keyword, unchanged; the
    awaited factory result is returned as-is.
    """
    factory_kwargs: dict[str, Any] = {
        "llm": llm,
        "search_space_id": search_space_id,
        "db_session": db_session,
        "connector_service": connector_service,
        "checkpointer": checkpointer,
        "user_id": user_id,
        "thread_id": thread_id,
        "agent_config": agent_config,
        "firecrawl_api_key": firecrawl_api_key,
        "thread_visibility": thread_visibility,
        "filesystem_selection": filesystem_selection,
        "disabled_tools": disabled_tools,
        "mentioned_document_ids": mentioned_document_ids,
    }
    return await agent_factory(**factory_kwargs)
async def stream_agent_events(
    agent: Any,
    config: dict[str, Any],
    input_data: Any,
    streaming_service: VercelStreamingService,
    result: StreamResult,
    step_prefix: str = "thinking",
    initial_step_id: str | None = None,
    initial_step_title: str = "",
    initial_step_items: list[str] | None = None,
    *,
    fallback_commit_search_space_id: int | None = None,
    fallback_commit_created_by_id: str | None = None,
    fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    fallback_commit_thread_id: int | None = None,
    runtime_context: Any = None,
    content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
    """Relay one agent turn as SSE frames, then inspect the final agent state.

    Every frame produced by ``stream_output`` is yielded unchanged. Once the
    relay is exhausted this coroutine:

    1. commits staged filesystem state that is still present in the checkpoint
       (cloud mode with a search space id only) and pushes the resulting delta
       back into the graph;
    2. copies the file-operation intent and confidence for the current turn
       onto ``result``;
    3. evaluates the file-write commit gate and, when it fails while
       enforcement is active, emits a notice that replaces the accumulated text;
    4. yields one interrupt-request frame per pending interrupt value.

    ``result`` is mutated in place and carries ``accumulated_text`` and the
    interrupt state once the generator finishes.
    """
    async for frame in stream_output(
        agent=agent,
        config=config,
        input_data=input_data,
        streaming_service=streaming_service,
        result=result,
        step_prefix=step_prefix,
        initial_step_id=initial_step_id,
        initial_step_title=initial_step_title,
        initial_step_items=initial_step_items,
        content_builder=content_builder,
        runtime_context=runtime_context,
    ):
        yield frame

    final_text = result.accumulated_text

    snapshot = await agent.aget_state(config)
    values = getattr(snapshot, "values", {}) or {}

    # Safety net: staged work still sitting in the checkpointed state means the
    # persistence hook did not get to commit it (e.g. the stream was cancelled
    # first). Run the shared commit helper and replay its delta with
    # ``as_node=...`` so reducers fire as if the after_agent hook produced it.
    staged_keys = (
        "dirty_paths",
        "staged_dirs",
        "pending_moves",
        "pending_deletes",
        "pending_dir_deletes",
    )
    needs_fallback_commit = (
        fallback_commit_filesystem_mode == FilesystemMode.CLOUD
        and fallback_commit_search_space_id is not None
        and any(values.get(key) for key in staged_keys)
    )
    if needs_fallback_commit:
        try:
            delta = await commit_staged_filesystem_state(
                values,
                search_space_id=fallback_commit_search_space_id,
                created_by_id=fallback_commit_created_by_id,
                filesystem_mode=fallback_commit_filesystem_mode,
                thread_id=fallback_commit_thread_id,
                dispatch_events=False,
            )
            if delta:
                await agent.aupdate_state(
                    config,
                    delta,
                    as_node="KnowledgeBasePersistenceMiddleware.after_agent",
                )
        except Exception as exc:
            _perf_log.warning("[stream_agent_events] safety-net commit failed: %s", exc)

    contract = values.get("file_operation_contract") or {}
    current_turn_id = config.get("configurable", {}).get("turn_id", "")
    is_current_turn = contract.get("turn_id") == current_turn_id
    intent = contract.get("intent")
    if isinstance(intent, str) and intent in ("chat_only", "file_write", "file_read"):
        # A contract left over from a previous turn/checkpoint is stale and is
        # downgraded to plain chat.
        result.intent_detected = intent if is_current_turn else "chat_only"
    if is_current_turn:
        result.intent_confidence = safe_float(contract.get("confidence"), default=0.0)
    else:
        result.intent_confidence = 0.0

    if result.intent_detected != "file_write":
        result.commit_gate_passed = True
        result.commit_gate_reason = ""
    else:
        gate_passed, gate_reason = evaluate_file_contract_outcome(result)
        result.commit_gate_passed = gate_passed
        result.commit_gate_reason = gate_reason
        if not result.commit_gate_passed and contract_enforcement_active(result):
            gate_notice = (
                "I could not complete the requested file write because no successful "
                "write_file/edit_file operation was confirmed."
            )
            gate_text_id = streaming_service.generate_text_id()
            # Each frame goes out before the matching builder callback runs.
            yield streaming_service.format_text_start(gate_text_id)
            if content_builder is not None:
                content_builder.on_text_start(gate_text_id)
            yield streaming_service.format_text_delta(gate_text_id, gate_notice)
            if content_builder is not None:
                content_builder.on_text_delta(gate_text_id, gate_notice)
            yield streaming_service.format_text_end(gate_text_id)
            if content_builder is not None:
                content_builder.on_text_end(gate_text_id)
            yield streaming_service.format_terminal_info(gate_notice, "error")
            final_text = gate_notice

    result.accumulated_text = final_text
    log_file_contract("turn_outcome", result)

    interrupts = all_interrupt_values(snapshot)
    if interrupts:
        result.is_interrupted = True
        # One frame per pending interrupt value, emitted in the order returned
        # by ``all_interrupt_values``.
        for interrupt_value in interrupts:
            yield streaming_service.format_interrupt_request(interrupt_value)
def extract_todos_from_deepagents(command_output: Any) -> dict:
    """Normalize todos out of a deepagents ``Command`` or dict payload.

    deepagents returns a ``Command`` whose ``update['todos']`` is a list of
    ``{'content': str, 'status': str}`` dicts. The UI expects the same shape,
    so no transformation is required — only extraction.

    Accepted inputs:

    * a plain dict with a top-level ``"todos"`` key;
    * a plain dict with an ``"update"`` key whose value is a dict holding
      ``"todos"``;
    * any object exposing an ``update`` attribute that is a dict holding
      ``"todos"`` (the ``Command`` case).

    Args:
        command_output: The tool output to inspect. Any other value
            (including ``None``) is tolerated.

    Returns:
        ``{"todos": <extracted value>}``; the value is an empty list when no
        todos could be located.
    """
    todos_data: list = []

    if isinstance(command_output, dict):
        # Dicts must be handled BEFORE the attribute probe: every dict has an
        # ``update`` *method*, so ``hasattr(d, "update")`` is always true and
        # calling ``.get`` on that bound method raises ``AttributeError``.
        if "todos" in command_output:
            todos_data = command_output.get("todos", [])
        elif isinstance(command_output.get("update"), dict):
            todos_data = command_output["update"].get("todos", [])
    else:
        # Command-like object. ``update`` may be missing or not a dict
        # (e.g. ``None``); only a dict can carry a ``"todos"`` entry.
        update = getattr(command_output, "update", None)
        if isinstance(update, dict):
            todos_data = update.get("todos", [])

    return {"todos": todos_data}
def evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
    """Score a turn against the file-write contract.

    Returns ``(passed, reason)``. Turns whose detected intent is anything
    other than ``"file_write"`` pass unconditionally. For file-write turns the
    flags are checked in order — attempted, succeeded, verified — and the
    first falsy one decides the failure reason. ``reason`` is ``""`` on pass.
    """
    if result.intent_detected != "file_write":
        return True, ""

    # (flag on ``result``, reason reported when that flag is falsy), in the
    # order the gate evaluates them.
    ordered_gates = (
        ("write_attempted", "no_write_attempt"),
        ("write_succeeded", "write_failed"),
        ("verification_succeeded", "verification_failed"),
    )
    for flag_name, failure_reason in ordered_gates:
        if not getattr(result, flag_name):
            return False, failure_reason
    return True, ""


def log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
    """Write one JSON record describing the contract state at ``stage``.

    The record is logged at info level through ``_perf_log`` with the
    ``[file_operation_contract]`` prefix. ``chat_id`` is the part of
    ``result.turn_id`` before the first ``":"`` (``"unknown"`` when there is
    no colon). Keyword arguments in ``extra`` are merged last, so they can
    add to or override the standard fields.
    """
    turn_id = result.turn_id
    chat_id = turn_id.split(":", 1)[0] if ":" in turn_id else "unknown"

    record: dict[str, Any] = {
        "stage": stage,
        "request_id": result.request_id or "unknown",
        "turn_id": turn_id or "unknown",
        "chat_id": chat_id,
    }
    # Fields copied verbatim from the result, in the order they are emitted.
    for field_name in (
        "filesystem_mode",
        "client_platform",
        "intent_detected",
        "intent_confidence",
        "write_attempted",
        "write_succeeded",
        "verification_succeeded",
        "commit_gate_passed",
    ):
        record[field_name] = getattr(result, field_name)
    record["commit_gate_reason"] = result.commit_gate_reason or None
    record.update(extra)

    serialized = json.dumps(record, ensure_ascii=False)
    _perf_log.info("[file_operation_contract] %s", serialized)
+ +Re-exports the public entry points so callers can write:: + + from app.tasks.chat.streaming.flows import stream_new_chat, stream_resume_chat + +The orchestrators themselves live under ``new_chat/orchestrator.py`` and +``resume_chat/orchestrator.py`` (slim composition of the per-concern modules in +each flow folder and the building blocks in ``shared/``). +""" + +from __future__ import annotations + +from app.tasks.chat.streaming.flows.new_chat import stream_new_chat +from app.tasks.chat.streaming.flows.resume_chat import stream_resume_chat + +__all__ = ["stream_new_chat", "stream_resume_chat"] diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/__init__.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/__init__.py new file mode 100644 index 000000000..566d5e0d9 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/__init__.py @@ -0,0 +1,12 @@ +"""New-chat streaming flow. + +The public entry point ``stream_new_chat`` is the slim coroutine in +``orchestrator.py`` that composes the per-concern modules in this folder and +the building blocks under ``flows/shared/``. +""" + +from __future__ import annotations + +from app.tasks.chat.streaming.flows.new_chat.orchestrator import stream_new_chat + +__all__ = ["stream_new_chat"] diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py new file mode 100644 index 000000000..af496cee7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py @@ -0,0 +1,91 @@ +"""Resolve the auto-pin for the *initial* turn config. + +Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for +this thread / search space / user, optionally filtered to vision-capable +configs when the turn carries images. 
@dataclass
class AutoPinResult:
    """Outcome of ``resolve_initial_auto_pin``.

    ``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
    classified user-facing message plus error code/kind so the orchestrator can
    emit one terminal-error SSE frame.

    Exactly one of the two fields is populated by ``resolve_initial_auto_pin``:
    the success path returns ``error=None`` and the ``ValueError`` path returns
    ``llm_config_id=None``.
    """

    # Resolved LLM config id on success; ``None`` when resolution failed.
    llm_config_id: int | None
    # ``(message, error_code, error_kind)`` on failure, ``None`` on success.
    # ``error_code`` is "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" (paired with
    # "user_error") or "SERVER_ERROR" (paired with "server_error").
    error: tuple[str, str, Literal["user_error", "server_error"]] | None
@dataclass
class InitialThinkingStep:
    """Payload of the first thinking step, shared by the SSE frame and the builder hook.

    ``title`` is the one-line step header and ``items`` the bullet list shown
    beneath it. ``step_id`` is hard-coded to ``thinking-1`` so the FE Timeline
    can de-duplicate against the prior assistant message on resume.
    """

    step_id: str
    title: str
    items: list[str]


def build_initial_thinking_step(
    *,
    user_query: str,
    user_image_data_urls: list[str] | None,
) -> InitialThinkingStep:
    """Derive the placeholder step from what the user sent.

    The single bullet reads ``Processing: <subject>`` where the subject is,
    in priority order: the query text (cut to 80 characters with a trailing
    ``...`` when longer), the attached image count, or ``(message)``.
    """
    if user_query.strip():
        # Slicing is applied to the raw query; only the emptiness check strips.
        subject = user_query if len(user_query) <= 80 else user_query[:80] + "..."
    elif user_image_data_urls:
        subject = f"[{len(user_image_data_urls)} image(s)]"
    else:
        subject = "(message)"

    return InitialThinkingStep(
        step_id="thinking-1",
        title="Understanding your request",
        items=[f"Processing: {subject}"],
    )


def iter_initial_thinking_step_frame(
    step: InitialThinkingStep,
    *,
    streaming_service: VercelStreamingService,
    content_builder: Any | None,
) -> Iterator[str]:
    """Yield the SSE frame for the initial step, notifying the builder first.

    The FE folds this step into the same singleton ``data-thinking-steps``
    part as the steps emitted later by the agent stream; driving the builder
    hook here mirrors that fold on the server side. Being a generator, the
    hook only fires once iteration starts.
    """
    status = "in_progress"
    if content_builder is not None:
        content_builder.on_thinking_step(step.step_id, step.title, status, step.items)
    frame = streaming_service.format_thinking_step(
        step_id=step.step_id,
        title=step.title,
        status=status,
        items=step.items,
    )
    yield frame
@dataclass
class NewChatInputState:
    """Everything ``build_new_chat_input_state`` produces.

    ``input_state`` is fed straight to the agent. ``accepted_folder_ids``
    feeds the runtime context (the resolver may have dropped some chips).
    """

    # Agent input dict. ``build_new_chat_input_state`` populates the keys
    # "messages", "search_space_id", "request_id" and "turn_id".
    input_state: dict[str, Any]
    # Folder ids returned by the mention resolver; an empty list when the
    # turn is not in cloud filesystem mode or carries no mentions.
    accepted_folder_ids: list[int]
async def _resolve_mentions_for_query(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_query: str,
    filesystem_mode: str,
    mentioned_document_ids: list[int] | None,
    mentioned_folder_ids: list[int] | None,
    mentioned_documents: list[dict[str, Any]] | None,
) -> tuple[str, list[int]]:
    r"""Rewrite @-mention tokens in the query to canonical document paths.

    Returns ``(agent_user_query, accepted_folder_ids)``.

    Only cloud filesystem mode with at least one mention is resolved; every
    other case returns the query unchanged together with an empty folder
    list. Local-folder mode keeps the legacy ``@title`` text path (mention
    support there is a follow-up task).

    The rewritten text is returned to the caller only — ``user_query`` itself
    is never modified, so surfaces that want what the user actually typed
    (persisted user turn, "Processing X" status, thread title, memory seed)
    keep seeing the ``@title`` tokens.
    """
    mentions_present = bool(
        mentioned_document_ids or mentioned_folder_ids or mentioned_documents
    )
    if filesystem_mode != FilesystemMode.CLOUD.value or not mentions_present:
        return user_query, []

    from app.schemas.new_chat import MentionedDocumentInfo

    chips: list[MentionedDocumentInfo] | None = None
    if mentioned_documents:
        chips = []
        for candidate in mentioned_documents:
            if not isinstance(candidate, MentionedDocumentInfo):
                # Raw payloads are validated; anything malformed is dropped
                # with a debug log rather than failing the turn.
                try:
                    candidate = MentionedDocumentInfo.model_validate(candidate)
                except Exception:
                    logger.debug(
                        "stream_new_chat: dropping malformed mention chip %r",
                        candidate,
                    )
                    continue
            chips.append(candidate)

    resolution = await resolve_mentions(
        session,
        search_space_id=search_space_id,
        mentioned_documents=chips,
        mentioned_document_ids=mentioned_document_ids,
        mentioned_folder_ids=mentioned_folder_ids,
    )
    rewritten_query = substitute_in_text(user_query, resolution.token_to_path)
    return rewritten_query, resolution.mentioned_folder_ids
def check_image_input_capability(
    *,
    user_image_data_urls: list[str] | None,
    agent_config: AgentConfig | None,
) -> tuple[str, str] | None:
    """Gate image-bearing turns against models known to be text-only.

    Returns ``None`` when the turn may proceed: no images attached, no agent
    config, or the model is not reported as text-only by
    ``is_known_text_only_chat_model``. Otherwise records a ``quota.denied``
    event and returns ``(user_message, error_code)`` for the caller to emit
    as one terminal-error SSE frame.
    """
    if not user_image_data_urls or agent_config is None:
        return None

    from app.services.provider_capabilities import is_known_text_only_chat_model

    params = agent_config.litellm_params or {}
    # ``base_model`` is only read when the params are actually a dict.
    base_model = params.get("base_model") if isinstance(params, dict) else None

    known_text_only = is_known_text_only_chat_model(
        provider=agent_config.provider,
        model_name=agent_config.model_name,
        base_model=base_model,
        custom_provider=agent_config.custom_provider,
    )
    if not known_text_only:
        return None

    error_code = "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
    label = agent_config.config_name or agent_config.model_name or "model"
    ot.add_event("quota.denied", {"quota.code": error_code})
    user_message = (
        f"The selected model ({label}) does not support image input. "
        "Switch to a vision-capable model (e.g. GPT-4o, Claude, Gemini) "
        "or remove the image attachment and try again."
    )
    return user_message, error_code
First SSE frames — message_start, start_step, turn-info, turn-status. + 5. Persistence join + message-id frames (ghost-thread protection). + 6. Initial thinking step + title task + runtime context. + 7. Stream loop with in-stream rate-limit recovery + mid-stream title emit. + 8. Finalize — premium debit, token-usage SSE, finish frames. + 9. Exception branch — classify, emit terminal error, finish frames. + 10. Finally — premium release, session close, assistant finalize, GC, span. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import time +from collections.abc import AsyncGenerator +from functools import partial +from typing import Any, Literal + +import anyio + +from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection +from app.agents.new_chat.middleware.busy_mutex import end_turn +from app.config import config as _app_config +from app.db import ChatVisibility, async_session_maker +from app.observability import otel as ot +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.content_builder import AssistantContentBuilder +from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread +from app.tasks.chat.streaming.contract.file_contract import log_file_contract +from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error +from app.tasks.chat.streaming.flows.new_chat.auto_pin import resolve_initial_auto_pin +from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import ( + build_initial_thinking_step, + iter_initial_thinking_step_frame, +) +from app.tasks.chat.streaming.flows.new_chat.input_state import ( + build_new_chat_input_state, +) +from app.tasks.chat.streaming.flows.new_chat.llm_capability import ( + check_image_input_capability, +) +from 
app.tasks.chat.streaming.flows.new_chat.persistence_spawn import ( + await_persist_task, + spawn_persist_assistant_shell_task, + spawn_persist_user_task, + spawn_set_ai_responding_bg, +) +from app.tasks.chat.streaming.flows.new_chat.runtime_context import ( + build_new_chat_runtime_context, +) +from app.tasks.chat.streaming.flows.new_chat.title_gen import ( + await_pending_title_update, + maybe_emit_title_update, + spawn_title_task, +) +from app.tasks.chat.streaming.flows.shared.assistant_finalize import ( + finalize_assistant_message, +) +from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame +from app.tasks.chat.streaming.flows.shared.finally_cleanup import ( + close_session_and_clear_ai_responding, + run_gc_pass, +) +from app.tasks.chat.streaming.flows.shared.first_frames import ( + iter_final_frames, + iter_initial_frames, +) +from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle +from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( + get_chat_checkpointer, + setup_connector_and_firecrawl, +) +from app.tasks.chat.streaming.flows.shared.premium_quota import ( + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, +) +from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( + can_recover_provider_rate_limit, + log_rate_limit_recovered, + reroute_to_next_auto_pin, +) +from app.tasks.chat.streaming.flows.shared.span import ( + close_chat_request_span, + open_chat_request_span, + set_agent_mode, +) +from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop +from app.tasks.chat.streaming.flows.shared.terminal_error import ( + handle_terminal_exception, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.utils.perf import get_perf_logger, log_system_snapshot + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + +# Holds spawned background tasks (set_ai_responding, 
persist_user, persist_asst) +# so the GC doesn't drop them before they finish. Kept at module level so it +# survives across turns within one process. +_background_tasks: set[asyncio.Task] = set() + + +async def stream_new_chat( + user_query: str, + search_space_id: int, + chat_id: int, + user_id: str | None = None, + llm_config_id: int = -1, + mentioned_document_ids: list[int] | None = None, + mentioned_folder_ids: list[int] | None = None, + mentioned_documents: list[dict[str, Any]] | None = None, + checkpoint_id: str | None = None, + needs_history_bootstrap: bool = False, + thread_visibility: ChatVisibility | None = None, + current_user_display_name: str | None = None, + disabled_tools: list[str] | None = None, + filesystem_selection: FilesystemSelection | None = None, + request_id: str | None = None, + user_image_data_urls: list[str] | None = None, + flow: Literal["new", "regenerate"] = "new", +) -> AsyncGenerator[str, None]: + """Stream a new chat turn using the SurfSense deep agent. + + Uses the Vercel AI SDK Data Stream Protocol (SSE). ``chat_id`` is the + LangGraph thread id (durable conversation memory via the checkpointer). + Manages its own database session so cleanup runs even when Starlette + cancels the task on client disconnect. 
+ """ + streaming_service = VercelStreamingService() + stream_result = StreamResult() + _t_total = time.perf_counter() + fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" + fs_platform = ( + filesystem_selection.client_platform.value if filesystem_selection else "web" + ) + stream_result.request_id = request_id + stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" + stream_result.filesystem_mode = fs_mode + stream_result.client_platform = fs_platform + + chat_agent_mode = "unknown" + chat_outcome = "success" + chat_error_category: str | None = None + chat_span_cm, chat_span = open_chat_request_span( + chat_id=chat_id, + search_space_id=search_space_id, + flow=flow, + request_id=request_id, + turn_id=stream_result.turn_id, + filesystem_mode=fs_mode, + client_platform=fs_platform, + agent_mode=chat_agent_mode, + ) + log_file_contract("turn_start", stream_result) + _perf_log.info( + "[stream_new_chat] filesystem_mode=%s client_platform=%s", + fs_mode, + fs_platform, + ) + log_system_snapshot("stream_new_chat_START") + + from app.services.token_tracking_service import start_turn + + accumulator = start_turn() + + premium_reservation: PremiumReservation | None = None + busy_error_raised = False + + emit_stream_error = partial( + emit_stream_terminal_error, + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + + session = async_session_maker() + # Declared at function scope so SSE-yield join points and the finally + # clause see them on every exit path. 
+ persist_user_task: asyncio.Task[int | None] | None = None + persist_asst_task: asyncio.Task[int | None] | None = None + try: + spawn_set_ai_responding_bg( + chat_id=chat_id, user_id=user_id, background_tasks=_background_tasks + ) + + # --- Block 1: LLM config + capability --- + + requested_llm_config_id = llm_config_id + requires_image_input = bool(user_image_data_urls) + + _t0 = time.perf_counter() + pin_result = await resolve_initial_auto_pin( + session, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + requires_image_input=requires_image_input, + requested_llm_config_id=requested_llm_config_id, + ) + if pin_result.error is not None: + message, error_code, error_kind = pin_result.error + yield emit_stream_error( + message=message, error_kind=error_kind, error_code=error_code + ) + yield streaming_service.format_done() + return + llm_config_id = pin_result.llm_config_id # type: ignore[assignment] + + llm, agent_config, llm_load_error = await load_llm_bundle( + session, config_id=llm_config_id, search_space_id=search_space_id + ) + if llm_load_error: + yield emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + _perf_log.info( + "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", + time.perf_counter() - _t0, + llm_config_id, + ) + + capability_error = check_image_input_capability( + user_image_data_urls=user_image_data_urls, agent_config=agent_config + ) + if capability_error is not None: + message, error_code = capability_error + yield emit_stream_error( + message=message, + error_kind="user_error", + error_code=error_code, + ) + yield streaming_service.format_done() + return + + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( + agent_config=agent_config, + user_id=user_id, # type: ignore[arg-type] + ) + if not premium_reservation.allowed: 
+ ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}) + if requested_llm_config_id == 0: + pin_fallback = await resolve_initial_auto_pin( + session, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + requires_image_input=requires_image_input, + requested_llm_config_id=requested_llm_config_id, + ) + if pin_fallback.error is not None: + message, error_code, error_kind = pin_fallback.error + yield emit_stream_error( + message=message, + error_kind=error_kind, + error_code=error_code, + ) + yield streaming_service.format_done() + return + llm_config_id = pin_fallback.llm_config_id # type: ignore[assignment] + ot.add_event( + "model.repin", + { + "repin.reason": "premium_quota_exhausted", + "repin.to_config_id": llm_config_id, + }, + ) + llm, agent_config, llm_load_error = await load_llm_bundle( + session, + config_id=llm_config_id, + search_space_id=search_space_id, + ) + if llm_load_error: + yield emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + premium_reservation = None + # Re-route to free fallback logged via the structured + # stream-error logger so cost/analytics see the auto-switch. 
+ from app.tasks.chat.streaming.errors.classifier import ( + log_chat_stream_error, + ) + + log_chat_stream_error( + flow=flow, + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; " + "auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, + ) + else: + yield emit_stream_error( + message=( + "Buy more tokens to continue with this model, or " + "switch to a free model" + ), + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, + ) + yield streaming_service.format_done() + return + + if not llm: + yield emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + # --- Block 2: Spawn concurrent persistence; build pre-stream setup --- + + persist_user_task = spawn_persist_user_task( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + mentioned_documents=mentioned_documents, + background_tasks=_background_tasks, + ) + persist_asst_task = spawn_persist_assistant_shell_task( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + background_tasks=_background_tasks, + ) + + _t0 = time.perf_counter() + connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( + session, search_space_id=search_space_id + ) + _perf_log.info( + "[stream_new_chat] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + _t0 = time.perf_counter() + checkpointer = await get_chat_checkpointer() + 
_perf_log.info( + "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) + chat_agent_mode = "multi" if use_multi_agent else "single" + set_agent_mode(chat_span, chat_agent_mode) + + _t0 = time.perf_counter() + agent_factory = ( + create_multi_agent_chat_deep_agent + if use_multi_agent + else create_surfsense_deep_agent + ) + # Build the agent inline. Provider 429s surface through the in-stream + # recovery loop below, which repins the thread to an eligible + # alternative config and rebuilds the agent before the user sees any + # output. + agent = await build_main_agent_for_thread( + agent_factory, + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + ) + _perf_log.info( + "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 + ) + + # --- Block 3: Input assembly --- + + _t0 = time.perf_counter() + assembled = await build_new_chat_input_state( + session, + chat_id=chat_id, + search_space_id=search_space_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + mentioned_document_ids=mentioned_document_ids, + mentioned_folder_ids=mentioned_folder_ids, + mentioned_documents=mentioned_documents, + needs_history_bootstrap=needs_history_bootstrap, + thread_visibility=visibility, + current_user_display_name=current_user_display_name, + filesystem_mode=fs_mode, + request_id=request_id, + turn_id=stream_result.turn_id, + ) + input_state = assembled.input_state + accepted_folder_ids = assembled.accepted_folder_ids + _perf_log.info( + 
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs", + time.perf_counter() - _t0, + ) + + # All pre-streaming DB reads done. Commit to release the transaction + # and its ACCESS SHARE locks so we don't block DDL (e.g. migrations) + # for the entire LLM streaming duration. Tools that need DB access + # during streaming start their own short-lived transactions (or use + # isolated sessions). + await session.commit() + # Detach heavy ORM objects (documents with chunks, reports, etc.) + # from the session identity map now that we've extracted what we + # need. Without this they accumulate in memory for the entire + # streaming duration (which can be several minutes). + session.expunge_all() + + _perf_log.info( + "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", + time.perf_counter() - _t_total, + chat_id, + ) + + configurable: dict[str, Any] = { + "thread_id": str(chat_id), + "request_id": request_id or "unknown", + "turn_id": stream_result.turn_id, + } + if checkpoint_id: + configurable["checkpoint_id"] = checkpoint_id + + config = { + "configurable": configurable, + # Effectively uncapped, matching the agent-level ``with_config`` + # default in ``chat_deepagent.create_agent`` and the unbounded + # ``while(true)`` in OpenCode's ``session/processor.ts``. Real + # circuit-breakers live in middleware (``DoomLoopMiddleware``, + # plus ``enable_tool_call_limit`` / ``enable_model_call_limit``). + # The original 25 (and our previous 80 bump) hit users on + # legitimate multi-tool plans. 
+ "recursion_limit": 10_000, + } + + # --- Block 4: First SSE frames --- + + for sse in iter_initial_frames( + streaming_service, turn_id=stream_result.turn_id + ): + yield sse + + # --- Block 5: Persistence join + message-id frames --- + + user_message_id = await await_persist_task( + persist_user_task, + chat_id=chat_id, + turn_id=stream_result.turn_id, + log_label="persist_user_task", + ) + if user_message_id is None: + yield emit_stream_error( + message="We couldn't save your message. Please try again in a moment.", + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + for sse in iter_final_frames(streaming_service): + yield sse + return + + # Emit canonical user message id BEFORE any LLM streaming so the FE + # can rename its optimistic ``msg-user-XXX`` placeholder to + # ``msg-{user_message_id}`` and unlock features gated on a real DB id + # (comments, edit-from-this-message). See B4 in the + # ``sse-based_message_id_handshake`` plan. + yield streaming_service.format_data( + "user-message-id", + {"message_id": user_message_id, "turn_id": stream_result.turn_id}, + ) + + assistant_message_id = await await_persist_task( + persist_asst_task, + chat_id=chat_id, + turn_id=stream_result.turn_id, + log_label="persist_asst_task", + ) + if assistant_message_id is None: + # Genuine DB failure — abort the turn rather than stream into a + # void. The user row is already persisted so the legacy + # ghost-thread gate isn't reopened. + yield emit_stream_error( + message=( + "We couldn't initialize the assistant message. Please try again." 
+ ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + for sse in iter_final_frames(streaming_service): + yield sse + return + + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() + + # --- Block 6: Initial thinking step + title task + runtime context --- + + initial_step = build_initial_thinking_step( + user_query=user_query, + user_image_data_urls=user_image_data_urls, + ) + for sse in iter_initial_thinking_step_frame( + initial_step, + streaming_service=streaming_service, + content_builder=stream_result.content_builder, + ): + yield sse + + initial_step_id = initial_step.step_id + initial_step_title = initial_step.title + initial_step_items = initial_step.items + # Drop the heavy ORM objects + the container that holds them so they + # aren't retained for the entire streaming duration. ``input_state`` + # already carries the langchain_messages list independently. 
+ del assembled + + title_task = spawn_title_task( + chat_id=chat_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + assistant_message_id=assistant_message_id, + llm=llm, + agent_config=agent_config, + ) + title_emitted = False + + runtime_context = build_new_chat_runtime_context( + search_space_id=search_space_id, + mentioned_document_ids=mentioned_document_ids, + accepted_folder_ids=accepted_folder_ids, + mentioned_folder_ids=mentioned_folder_ids, + request_id=request_id, + turn_id=stream_result.turn_id, + ) + + # --- Block 7: Stream loop --- + + _t_stream_start = time.perf_counter() + runtime_rate_limit_recovered = False + + def _on_first_event() -> None: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + + async def _recover(exc: BaseException, first_event_seen: bool): + nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered + nonlocal title_task + if not can_recover_provider_rate_limit( + exc, + first_event_seen=first_event_seen, + runtime_rate_limit_recovered=runtime_rate_limit_recovered, + requested_llm_config_id=requested_llm_config_id, + current_llm_config_id=llm_config_id, + ): + return None + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + llm_config_id = await reroute_to_next_auto_pin( + session, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + current_llm_config_id=llm_config_id, + requires_image_input=requires_image_input, + ) + new_llm, new_agent_config, llm_load_err = await load_llm_bundle( + session, config_id=llm_config_id, search_space_id=search_space_id + ) + if llm_load_err: + # Re-raise the original so the terminal-error path classifies + # it correctly (don't swallow as "config load error"). 
+ return None + llm = new_llm + agent_config = new_agent_config + + # Title gen used the initial llm object. After a runtime repin we + # keep the stream focused on response recovery and skip title gen + # for this turn. + if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t_rebuild = time.perf_counter() + new_agent = await build_main_agent_for_thread( + agent_factory, + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t_rebuild, + ) + log_rate_limit_recovered( + flow=flow, + request_id=request_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + previous_config_id=previous_config_id, + new_config_id=llm_config_id, + ) + return new_agent + + async for sse in run_stream_loop( + agent=agent, + streaming_service=streaming_service, + config=config, + input_data=input_state, + stream_result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, + content_builder=stream_result.content_builder, + recover=_recover, + on_first_event=_on_first_event, + ): + yield sse + # Inject 
the title update mid-stream as soon as the background + # task finishes; gated so we emit at most once. + async for title_sse in maybe_emit_title_update( + title_task=title_task, + title_emitted=title_emitted, + chat_id=chat_id, + accumulator=accumulator, + streaming_service=streaming_service, + ): + yield title_sse + title_emitted = True + # Account for the case where the task completed but produced no + # title — flip the flag anyway so we don't keep checking it. + if title_task is not None and title_task.done() and not title_emitted: + title_emitted = True + + _perf_log.info( + "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", + time.perf_counter() - _t_stream_start, + chat_id, + ) + log_system_snapshot("stream_new_chat_END") + + # --- Block 8: Finalize --- + + if stream_result.is_interrupted: + ot.add_event("chat.interrupted", {"chat.flow": flow}) + if title_task is not None and not title_task.done(): + title_task.cancel() + for sse in iter_token_usage_frame( + streaming_service, + accumulator=accumulator, + log_label="interrupted new_chat", + ): + yield sse + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + async for title_sse in await_pending_title_update( + title_task=title_task, + title_emitted=title_emitted, + chat_id=chat_id, + accumulator=accumulator, + streaming_service=streaming_service, + ): + yield title_sse + + # Finalize premium credit debit with the actual provider cost reported + # by LiteLLM, summed across every call in the turn. Mirrors the + # pre-cost behaviour of "premium turn → all calls count" so free + # sub-agent calls during a premium turn still contribute to the bill + # (they're $0 in practice anyway). 
+ if premium_reservation is not None and user_id: + await finalize_premium( + reservation=premium_reservation, + user_id=user_id, + accumulator=accumulator, + ) + premium_reservation = None + + for sse in iter_token_usage_frame( + streaming_service, accumulator=accumulator, log_label="normal new_chat" + ): + yield sse + + for sse in iter_final_frames(streaming_service): + yield sse + + except Exception as exc: + frames, summary = handle_terminal_exception( + exc, + flow=flow, + flow_label="chat", + log_prefix="stream_new_chat", + streaming_service=streaming_service, + request_id=request_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + chat_span=chat_span, + ) + if summary["busy_error_raised"]: + busy_error_raised = True + chat_outcome = summary["chat_outcome"] + chat_error_category = summary["chat_error_category"] + for sse in frames: + yield sse + + finally: + # Shield the ENTIRE async cleanup from anyio cancel-scope cancellation. + # Starlette's BaseHTTPMiddleware uses anyio task groups; on client + # disconnect, it cancels the scope with level-triggered cancellation + # — every unshielded ``await`` would raise CancelledError immediately. + # Without this the very first ``await`` (session.rollback) would + # raise, ``except Exception`` wouldn't catch it (CancelledError is a + # BaseException), and the rest of cleanup — including session.close() + # — would never run. + with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. 
+ end_turn(str(chat_id)) + + if premium_reservation is not None and user_id: + await release_premium(reservation=premium_reservation, user_id=user_id) + + await close_session_and_clear_ai_responding(session, chat_id) + + await finalize_assistant_message( + stream_result=stream_result, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + accumulator=accumulator, + log_prefix="stream_new_chat", + ) + + # Persist any sandbox-produced files to local storage so they remain + # downloadable after the Daytona sandbox auto-deletes. + if stream_result and stream_result.sandbox_files: + with contextlib.suppress(Exception): + from app.agents.new_chat.sandbox import ( + is_sandbox_enabled, + persist_and_delete_sandbox, + ) + + if is_sandbox_enabled(): + with anyio.CancelScope(shield=True): + await persist_and_delete_sandbox( + chat_id, stream_result.sandbox_files + ) + + # ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout. + # Skip on ``BusyError`` (caller never acquired the lock). + if not busy_error_raised: + with contextlib.suppress(Exception): + end_turn(str(chat_id)) + _perf_log.info( + "[stream_new_chat] end_turn cleanup (chat_id=%s)", chat_id + ) + + # Break circular refs held by the agent graph, tools, and LLM + # wrappers so the GC can reclaim them in a single pass. 
+ agent = llm = connector_service = None + input_state = stream_result = None + session = None + + run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id) + close_chat_request_span( + span_cm=chat_span_cm, + span=chat_span, + chat_outcome=chat_outcome, + chat_agent_mode=chat_agent_mode, + flow=flow, + chat_error_category=chat_error_category, + duration_seconds=time.perf_counter() - _t_total, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py new file mode 100644 index 000000000..9ea5d2ad6 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py @@ -0,0 +1,129 @@ +"""Concurrent persistence tasks spawned right after the initial validation gate. + +These run *during* the rest of the pre-stream setup so we don't serialize +their latency against agent construction. Awaiting them at the SSE message-id +yield sites preserves the ghost-thread protection (the user-row INSERT must +succeed before any LLM streaming begins). + +The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own +shielded session — failures only delay the "AI is responding…" UI flag, not +the response itself. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from uuid import UUID + +from app.db import shielded_async_session +from app.services.chat_session_state_service import set_ai_responding +from app.tasks.chat.persistence import ( + persist_assistant_shell, + persist_user_turn, +) + +logger = logging.getLogger(__name__) + + +def spawn_set_ai_responding_bg( + *, + chat_id: int, + user_id: str | None, + background_tasks: set[asyncio.Task[Any]], +) -> None: + """Fire-and-forget: flip the per-thread AI-responding flag on its own session. 
+ + Errors are swallowed and logged — the worst case is a stale UI flag, which + is preferable to delaying the SSE stream behind a flag write. + """ + if not user_id: + return + + async def _bg_set_ai_responding() -> None: + try: + async with shielded_async_session() as s: + await set_ai_responding(s, chat_id, UUID(user_id)) + except Exception: + logger.warning( + "set_ai_responding failed (chat_id=%s)", + chat_id, + exc_info=True, + ) + + t = asyncio.create_task(_bg_set_ai_responding()) + background_tasks.add(t) + t.add_done_callback(background_tasks.discard) + + +def spawn_persist_user_task( + *, + chat_id: int, + user_id: str | None, + turn_id: str, + user_query: str, + user_image_data_urls: list[str] | None, + mentioned_documents: list[dict[str, Any]] | None, + background_tasks: set[asyncio.Task[Any]], +) -> asyncio.Task[int | None]: + """Spawn the user-row INSERT; await at the user-message-id yield site.""" + task = asyncio.create_task( + persist_user_turn( + chat_id=chat_id, + user_id=user_id, + turn_id=turn_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + mentioned_documents=mentioned_documents, + ) + ) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return task + + +def spawn_persist_assistant_shell_task( + *, + chat_id: int, + user_id: str | None, + turn_id: str, + background_tasks: set[asyncio.Task[Any]], +) -> asyncio.Task[int | None]: + """Spawn the assistant-shell INSERT; await at the assistant-message-id yield site.""" + task = asyncio.create_task( + persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=turn_id, + ) + ) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return task + + +async def await_persist_task( + task: asyncio.Task[int | None] | None, + *, + chat_id: int, + turn_id: str, + log_label: str, +) -> int | None: + """Join a spawned persistence task with ``shield`` + uniform error handling. 
+ + ``shield`` keeps the DB write alive if the SSE generator is cancelled by + client disconnect mid-await. Returns ``None`` on failure; the caller + abort-paths the turn with a friendly error SSE. + """ + if task is None: + return None + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + raise + except Exception: + logger.exception( + "%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id + ) + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py new file mode 100644 index 000000000..cf1e8c3fb --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py @@ -0,0 +1,36 @@ +"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn. + +Carries the per-turn read inputs that middlewares read via +``runtime.context.*`` instead of from their ``__init__`` closures, so the same +compiled-agent instance can serve multiple turns with different +mention lists / request ids / turn ids without rebuilding the graph. +""" + +from __future__ import annotations + +from app.agents.new_chat.context import SurfSenseContextSchema + + +def build_new_chat_runtime_context( + *, + search_space_id: int, + mentioned_document_ids: list[int] | None, + accepted_folder_ids: list[int], + mentioned_folder_ids: list[int] | None, + request_id: str | None, + turn_id: str, +) -> SurfSenseContextSchema: + """``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``. + + ``accepted_folder_ids`` (post-resolve) wins over the raw + ``mentioned_folder_ids`` from the request: the resolver drops chips that + pointed at deleted folders or folders the caller can't see, so middlewares + only get authorized ids. 
+ """ + return SurfSenseContextSchema( + search_space_id=search_space_id, + mentioned_document_ids=list(mentioned_document_ids or []), + mentioned_folder_ids=list(accepted_folder_ids or mentioned_folder_ids or []), + request_id=request_id, + turn_id=turn_id, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py new file mode 100644 index 000000000..7db45941b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -0,0 +1,233 @@ +"""Background thread-title generation (first-response only). + +The first assistant response in a thread gets a short auto-generated title +inserted into ``new_chat_threads.title``. We: + + 1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with + the agent stream (no extra TTFT). + 2. Probe inside the task (on its own shielded session) whether this is + actually the first response — newer turns short-circuit to ``None``. + 3. Inject the resulting ``thread-title-update`` SSE frame on the first agent + event after the task completes (mid-stream interlock), or right before + the finish frames (post-stream join) if the task hadn't finished yet. + +Usage tokens come directly off the response (LiteLLM's async callback fires +via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would +run too late). We also blank the per-task accumulator so the late callback +doesn't double-count. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from sqlalchemy.future import select + +from app.db import NewChatMessage, NewChatThread, shielded_async_session +from app.prompts import TITLE_GENERATION_PROMPT +from app.services.new_streaming_service import VercelStreamingService + +if TYPE_CHECKING: + from app.agents.new_chat.llm_config import AgentConfig + from app.services.token_tracking_service import TokenAccumulator + + +logger = logging.getLogger(__name__) + + +def spawn_title_task( + *, + chat_id: int, + user_query: str, + user_image_data_urls: list[str] | None, + assistant_message_id: int | None, + llm: Any, + agent_config: AgentConfig | None, +) -> asyncio.Task[tuple[str | None, dict | None]] | None: + """Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met. + + Title gen is gated on a real ``assistant_message_id`` so a stream that + aborts before persistence can never leave a thread with a title and no + anchoring rows. + """ + if assistant_message_id is None: + return None + return asyncio.create_task( + _generate_title( + chat_id=chat_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + assistant_message_id=assistant_message_id, + llm=llm, + agent_config=agent_config, + ) + ) + + +async def _generate_title( + *, + chat_id: int, + user_query: str, + user_image_data_urls: list[str] | None, + assistant_message_id: int, + llm: Any, + agent_config: AgentConfig | None, +) -> tuple[str | None, dict | None]: + """Probe is-first-response, then call ``acompletion``. 
Returns ``(title, usage)``.""" + try: + from litellm import acompletion + + from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base + from app.services.token_tracking_service import _turn_accumulator + + # Excludes this turn's own assistant row (pre-written by + # ``persist_assistant_shell``) — without the ``!=`` filter the gate + # would false-negative on every turn after the first. + try: + async with shielded_async_session() as probe_session: + probe_result = await probe_session.execute( + select(NewChatMessage.id) + .filter( + NewChatMessage.thread_id == chat_id, + NewChatMessage.role == "assistant", + NewChatMessage.id != assistant_message_id, + ) + .limit(1) + ) + is_first_response = probe_result.scalars().first() is None + except Exception: + logger.warning( + "[TitleGen] first-response probe failed (chat_id=%s)", + chat_id, + exc_info=True, + ) + return None, None + + if not is_first_response: + return None, None + + _turn_accumulator.set(None) + + title_seed = user_query.strip() or ( + f"[{len(user_image_data_urls or [])} image(s)]" + if user_image_data_urls + else "" + ) + prompt = TITLE_GENERATION_PROMPT.replace( + "{user_query}", title_seed[:500] or "(message)" + ) + messages = [{"role": "user", "content": prompt}] + + if getattr(llm, "model", None) == "auto": + router = LLMRouterService.get_router() + response = await router.acompletion(model="auto", messages=messages) + else: + # Apply the same ``api_base`` cascade chat / vision / image-gen + # call sites use so we never inherit ``litellm.api_base`` + # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat + # config itself ships an empty ``api_base``. Without this the + # title-gen on an OpenRouter chat config would 404 against the + # inherited Azure endpoint — see ``provider_api_base`` for the + # same bug repro on the image-gen / vision paths. 
+ raw_model = getattr(llm, "model", "") or "" + provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None + provider_value = agent_config.provider if agent_config is not None else None + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) + response = await acompletion( + model=raw_model, + messages=messages, + api_key=getattr(llm, "api_key", None), + api_base=title_api_base, + ) + + usage_info = None + usage = getattr(response, "usage", None) + if usage: + raw_model = getattr(llm, "model", "") or "" + model_name = ( + raw_model.split("/", 1)[-1] + if "/" in raw_model + else (raw_model or response.model or "unknown") + ) + usage_info = { + "model": model_name, + "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage, "total_tokens", 0) or 0, + } + + raw_title = response.choices[0].message.content.strip() + if raw_title and len(raw_title) <= 100: + return raw_title.strip("\"'"), usage_info + return None, usage_info + except Exception: + logger.exception("[TitleGen] _generate_title failed") + return None, None + + +async def maybe_emit_title_update( + *, + title_task: asyncio.Task[tuple[str | None, dict | None]] | None, + title_emitted: bool, + chat_id: int, + accumulator: TokenAccumulator, + streaming_service: VercelStreamingService, +): + """Inject one ``thread-title-update`` SSE if the task completed. + + Yields the SSE frame (when applicable). Returns nothing; the orchestrator + flips ``title_emitted`` itself after iterating so we don't fight Python's + nonlocal-in-generator semantics. 
+ """ + if title_task is None or title_emitted or not title_task.done(): + return + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter(NewChatThread.id == chat_id) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update(chat_id, generated_title) + + +async def await_pending_title_update( + *, + title_task: asyncio.Task[tuple[str | None, dict | None]] | None, + title_emitted: bool, + chat_id: int, + accumulator: TokenAccumulator, + streaming_service: VercelStreamingService, +): + """If the task hadn't completed during the stream, await it now and emit. + + Used right before the finish frames in the success path. Mirror of + ``maybe_emit_title_update`` but unconditionally awaits. + """ + if title_task is None or title_emitted: + return + generated_title, title_usage = await title_task + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter(NewChatThread.id == chat_id) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update(chat_id, generated_title) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/__init__.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/__init__.py new file mode 100644 index 000000000..ed0683e19 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/__init__.py @@ -0,0 +1,12 @@ +"""Resume-chat streaming flow. 
+ +Public entry point ``stream_resume_chat`` is the slim coroutine in +``orchestrator.py`` that composes the per-concern modules in this folder and +the building blocks under ``flows/shared/``. +""" + +from __future__ import annotations + +from app.tasks.chat.streaming.flows.resume_chat.orchestrator import stream_resume_chat + +__all__ = ["stream_resume_chat"] diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/assistant_shell.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/assistant_shell.py new file mode 100644 index 000000000..2f34387f8 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/assistant_shell.py @@ -0,0 +1,31 @@ +"""Pre-write a fresh assistant row for this resume turn. + +The original (interrupted) ``stream_new_chat`` invocation already persisted +its own assistant row anchored to a different ``turn_id``; resume allocates a +new ``turn_id`` (per-request, see ``orchestrator``) so we need a separate row +keyed on the same ``(thread_id, turn_id, ASSISTANT)`` invariant. + +Idempotent against migration 141's partial unique index — recovers the +existing id on retry. + +Resume does NOT emit ``data-user-message-id``: the user row is from the +original interrupted turn (different ``turn_id``) and is never re-persisted +here. See B5 in the ``sse-based_message_id_handshake`` plan. 
+""" + +from __future__ import annotations + +from app.tasks.chat.persistence import persist_assistant_shell + + +async def persist_resume_assistant_shell( + *, + chat_id: int, + user_id: str | None, + turn_id: str, +) -> int | None: + return await persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=turn_id, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py new file mode 100644 index 000000000..e1b95aa63 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -0,0 +1,624 @@ +"""``stream_resume_chat`` — public entry point for a HITL resume turn. + +Slim composition layer over the per-concern modules in this folder and the +building blocks under ``flows/shared/``. Mirrors ``stream_new_chat`` but: + + * No user-message persistence (the original turn already wrote it). + * No mentions / surfsense-doc / report context assembly (seeded by original). + * No title generation (only fires on first-response). + * Synchronous ``persist_assistant_shell`` call (we have no other in-flight + pre-stream work to overlap it with). + * ``input_data`` is a ``Command(resume=lg_resume_map)`` instead of a + LangChain message list. 
+""" + +from __future__ import annotations + +import contextlib +import logging +import time +from collections.abc import AsyncGenerator +from functools import partial +from uuid import UUID + +import anyio + +from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent +from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection +from app.agents.new_chat.middleware.busy_mutex import end_turn +from app.config import config as _app_config +from app.db import ChatVisibility, async_session_maker +from app.observability import otel as ot +from app.services.chat_session_state_service import set_ai_responding +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.content_builder import AssistantContentBuilder +from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread +from app.tasks.chat.streaming.contract.file_contract import log_file_contract +from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error +from app.tasks.chat.streaming.flows.resume_chat.assistant_shell import ( + persist_resume_assistant_shell, +) +from app.tasks.chat.streaming.flows.resume_chat.resume_routing import ( + build_resume_routing, +) +from app.tasks.chat.streaming.flows.resume_chat.runtime_context import ( + build_resume_chat_runtime_context, +) +from app.tasks.chat.streaming.flows.shared.assistant_finalize import ( + finalize_assistant_message, +) +from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame +from app.tasks.chat.streaming.flows.shared.finally_cleanup import ( + close_session_and_clear_ai_responding, + run_gc_pass, +) +from app.tasks.chat.streaming.flows.shared.first_frames import ( + iter_final_frames, + iter_initial_frames, +) +from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle +from 
app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( + get_chat_checkpointer, + setup_connector_and_firecrawl, +) +from app.tasks.chat.streaming.flows.shared.premium_quota import ( + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, +) +from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( + can_recover_provider_rate_limit, + log_rate_limit_recovered, + reroute_to_next_auto_pin, +) +from app.tasks.chat.streaming.flows.shared.span import ( + close_chat_request_span, + open_chat_request_span, + set_agent_mode, +) +from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop +from app.tasks.chat.streaming.flows.shared.terminal_error import ( + handle_terminal_exception, +) +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.tasks.chat.streaming.shared.utils import resume_step_prefix +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +async def stream_resume_chat( + chat_id: int, + search_space_id: int, + decisions: list[dict], + user_id: str | None = None, + llm_config_id: int = -1, + thread_visibility: ChatVisibility | None = None, + filesystem_selection: FilesystemSelection | None = None, + request_id: str | None = None, + disabled_tools: list[str] | None = None, +) -> AsyncGenerator[str, None]: + """Resume a paused HITL turn with the user's decisions. + + Mirrors ``stream_new_chat`` except for the resume-specific routing of + ``decisions`` to per-``tool_call_id`` slices (``build_resume_routing``). 
+ """ + streaming_service = VercelStreamingService() + stream_result = StreamResult() + _t_total = time.perf_counter() + fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud" + fs_platform = ( + filesystem_selection.client_platform.value if filesystem_selection else "web" + ) + stream_result.request_id = request_id + stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}" + stream_result.filesystem_mode = fs_mode + stream_result.client_platform = fs_platform + + chat_agent_mode = "unknown" + chat_outcome = "success" + chat_error_category: str | None = None + chat_span_cm, chat_span = open_chat_request_span( + chat_id=chat_id, + search_space_id=search_space_id, + flow="resume", + request_id=request_id, + turn_id=stream_result.turn_id, + filesystem_mode=fs_mode, + client_platform=fs_platform, + agent_mode=chat_agent_mode, + ) + log_file_contract("turn_start", stream_result) + _perf_log.info( + "[stream_resume] filesystem_mode=%s client_platform=%s", + fs_mode, + fs_platform, + ) + + from app.services.token_tracking_service import start_turn + + accumulator = start_turn() + + premium_reservation: PremiumReservation | None = None + busy_error_raised = False + + emit_stream_error = partial( + emit_stream_terminal_error, + streaming_service=streaming_service, + flow="resume", + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) + + session = async_session_maker() + try: + if user_id: + await set_ai_responding(session, chat_id, UUID(user_id)) + + requested_llm_config_id = llm_config_id + + # --- LLM config --- + + _t0 = time.perf_counter() + try: + from app.services.auto_model_pin_service import ( + resolve_or_get_pinned_llm_config_id, + ) + + pinned = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + ) + llm_config_id = pinned.resolved_llm_config_id + ot.add_event( + 
"model.pin.resolved", + { + "pin.requested_id": requested_llm_config_id, + "pin.resolved_id": llm_config_id, + "pin.requires_image_input": False, + }, + ) + except ValueError as pin_error: + yield emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await load_llm_bundle( + session, config_id=llm_config_id, search_space_id=search_space_id + ) + if llm_load_error: + yield emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + _perf_log.info( + "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 + ) + + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( + agent_config=agent_config, + user_id=user_id, # type: ignore[arg-type] + ) + if not premium_reservation.allowed: + ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}) + if requested_llm_config_id == 0: + try: + pinned_fb = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + llm_config_id = pinned_fb.resolved_llm_config_id + ot.add_event( + "model.repin", + { + "repin.reason": "premium_quota_exhausted", + "repin.to_config_id": llm_config_id, + }, + ) + except ValueError as pin_error: + yield emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + llm, agent_config, llm_load_error = await load_llm_bundle( + session, + config_id=llm_config_id, + search_space_id=search_space_id, + ) + if llm_load_error: + yield emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return 
+ premium_reservation = None + from app.tasks.chat.streaming.errors.classifier import ( + log_chat_stream_error, + ) + + log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; " + "auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, + ) + else: + yield emit_stream_error( + message=( + "Buy more tokens to continue with this model, or " + "switch to a free model" + ), + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, + ) + yield streaming_service.format_done() + return + + if not llm: + yield emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + # --- Pre-stream setup --- + + _t0 = time.perf_counter() + connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( + session, search_space_id=search_space_id + ) + _perf_log.info( + "[stream_resume] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, + ) + + _t0 = time.perf_counter() + checkpointer = await get_chat_checkpointer() + _perf_log.info( + "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + visibility = thread_visibility or ChatVisibility.PRIVATE + use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED) + chat_agent_mode = "multi" if use_multi_agent else "single" + set_agent_mode(chat_span, chat_agent_mode) + + _t0 = time.perf_counter() + agent_factory = ( + create_multi_agent_chat_deep_agent + if use_multi_agent + else create_surfsense_deep_agent + 
) + agent = await build_main_agent_for_thread( + agent_factory, + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + ) + _perf_log.info( + "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 + ) + + # Release the transaction before streaming (same rationale as stream_new_chat). + await session.commit() + session.expunge_all() + + _perf_log.info( + "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", + time.perf_counter() - _t_total, + chat_id, + ) + + # --- Resume routing --- + + from langgraph.types import Command + + routing = await build_resume_routing( + agent, chat_id=chat_id, decisions=decisions + ) + + config = { + "configurable": { + "thread_id": str(chat_id), + "request_id": request_id or "unknown", + "turn_id": stream_result.turn_id, + # Per-``tool_call_id`` resume slices read by + # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel + # siblings each pop their own entry, so they never race. + "surfsense_resume_value": routing.routed_resume_value, + }, + # Same rationale as ``stream_new_chat``: effectively uncapped to + # mirror the agent default and OpenCode's session loop. Doom-loop + # / call-limit middleware enforce the real ceiling. + "recursion_limit": 10_000, + } + + # --- First SSE frames --- + + for sse in iter_initial_frames( + streaming_service, turn_id=stream_result.turn_id + ): + yield sse + + # --- Assistant-shell persistence + id frame --- + + assistant_message_id = await persist_resume_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + ) + if assistant_message_id is None: + yield emit_stream_error( + message=( + "We couldn't initialize the assistant message. 
Please try again." + ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + for sse in iter_final_frames(streaming_service): + yield sse + return + + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() + + runtime_context = build_resume_chat_runtime_context( + search_space_id=search_space_id, + request_id=request_id, + turn_id=stream_result.turn_id, + ) + + # --- Stream loop --- + + _t_stream_start = time.perf_counter() + runtime_rate_limit_recovered = False + + def _on_first_event() -> None: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + + async def _recover(exc: BaseException, first_event_seen: bool): + nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered + if not can_recover_provider_rate_limit( + exc, + first_event_seen=first_event_seen, + runtime_rate_limit_recovered=runtime_rate_limit_recovered, + requested_llm_config_id=requested_llm_config_id, + current_llm_config_id=llm_config_id, + ): + return None + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + llm_config_id = await reroute_to_next_auto_pin( + session, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + current_llm_config_id=llm_config_id, + requires_image_input=False, + ) + new_llm, new_agent_config, llm_load_err = await load_llm_bundle( + session, config_id=llm_config_id, search_space_id=search_space_id + ) + if llm_load_err: + return None + llm = new_llm + agent_config = new_agent_config + + _t_rebuild = time.perf_counter() + new_agent = await build_main_agent_for_thread( + agent_factory, + llm=llm, + search_space_id=search_space_id, + db_session=session, + 
connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + ) + _perf_log.info( + "[stream_resume] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t_rebuild, + ) + log_rate_limit_recovered( + flow="resume", + request_id=request_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + previous_config_id=previous_config_id, + new_config_id=llm_config_id, + ) + return new_agent + + async for sse in run_stream_loop( + agent=agent, + streaming_service=streaming_service, + config=config, + input_data=Command(resume=routing.lg_resume_map), + stream_result=stream_result, + step_prefix=resume_step_prefix(stream_result.turn_id), + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + runtime_context=runtime_context, + content_builder=stream_result.content_builder, + recover=_recover, + on_first_event=_on_first_event, + ): + yield sse + + _perf_log.info( + "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", + time.perf_counter() - _t_stream_start, + chat_id, + ) + + # --- Finalize --- + + if stream_result.is_interrupted: + ot.add_event("chat.interrupted", {"chat.flow": "resume"}) + for sse in iter_token_usage_frame( + streaming_service, + accumulator=accumulator, + log_label="interrupted resume_chat", + ): + yield sse + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + if premium_reservation is not None and user_id: + await 
finalize_premium( + reservation=premium_reservation, + user_id=user_id, + accumulator=accumulator, + ) + premium_reservation = None + + for sse in iter_token_usage_frame( + streaming_service, accumulator=accumulator, log_label="normal resume_chat" + ): + yield sse + + for sse in iter_final_frames(streaming_service): + yield sse + + except Exception as exc: + frames, summary = handle_terminal_exception( + exc, + flow="resume", + flow_label="resume", + log_prefix="stream_resume_chat", + streaming_service=streaming_service, + request_id=request_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + chat_span=chat_span, + ) + if summary["busy_error_raised"]: + busy_error_raised = True + chat_outcome = summary["chat_outcome"] + chat_error_category = summary["chat_error_category"] + for sse in frames: + yield sse + + finally: + with anyio.CancelScope(shield=True): + end_turn(str(chat_id)) + + if premium_reservation is not None and user_id: + await release_premium(reservation=premium_reservation, user_id=user_id) + + await close_session_and_clear_ai_responding(session, chat_id) + + await finalize_assistant_message( + stream_result=stream_result, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + accumulator=accumulator, + log_prefix="stream_resume", + ) + + # Release the lock from the original interrupted turn or any + # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here). 
+ if not busy_error_raised: + with contextlib.suppress(Exception): + end_turn(str(chat_id)) + _perf_log.info("[stream_resume] end_turn cleanup (chat_id=%s)", chat_id) + + agent = llm = connector_service = None + stream_result = None + session = None + + run_gc_pass(log_prefix="stream_resume", chat_id=chat_id) + close_chat_request_span( + span_cm=chat_span_cm, + span=chat_span, + chat_outcome=chat_outcome, + chat_agent_mode=chat_agent_mode, + flow="resume", + chat_error_category=chat_error_category, + duration_seconds=time.perf_counter() - _t_total, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py new file mode 100644 index 000000000..7f4f67aac --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/resume_routing.py @@ -0,0 +1,63 @@ +"""Route a flat ``decisions`` list back to the right paused subagent. + +Each pending interrupt is stamped with its originating ``tool_call_id`` (see +``checkpointed_subagent_middleware.propagation``) so the resume slicer can +re-target each ``HumanReview`` decision at the right ``tool_call_id``. + +LangGraph rejects scalar ``Command(resume=...)`` when multiple interrupts are +pending (parallel HITL); the mapped form works for the single-pause case too, +so we always use it. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +from app.utils.perf import get_perf_logger + +_perf_log = get_perf_logger() +logger = logging.getLogger(__name__) + + +@dataclass +class ResumeRoutingPayload: + """Resolved per-``tool_call_id`` resume slices + the lg-shaped resume map.""" + + routed_resume_value: dict[str, Any] + lg_resume_map: dict[str, Any] + + +async def build_resume_routing( + agent: Any, + *, + chat_id: int, + decisions: list[dict], +) -> ResumeRoutingPayload: + """Read parent_state, collect pending tool-calls, slice decisions, build map. + + The middleware reads its per-``tool_call_id`` resume slice from the + ``surfsense_resume_value`` configurable; parallel siblings each pop their + own entry so they never race. + """ + from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, + ) + + parent_state = await agent.aget_state({"configurable": {"thread_id": str(chat_id)}}) + pending = collect_pending_tool_calls(parent_state) + _perf_log.info( + "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d", + chat_id, + len(decisions), + len(pending), + ) + routed_resume_value = slice_decisions_by_tool_call(decisions, pending) + lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value) + return ResumeRoutingPayload( + routed_resume_value=routed_resume_value, + lg_resume_map=lg_resume_map, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py new file mode 100644 index 000000000..59d5d8ca7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/runtime_context.py @@ -0,0 +1,23 @@ +"""Build the per-invocation ``SurfSenseContextSchema`` for a resume turn. 
+ +Resume doesn't carry new ``mentioned_document_ids`` (those are seeded by the +original turn). We still build the context so future middleware extensions +can rely on ``runtime.context`` always being populated. +""" + +from __future__ import annotations + +from app.agents.new_chat.context import SurfSenseContextSchema + + +def build_resume_chat_runtime_context( + *, + search_space_id: int, + request_id: str | None, + turn_id: str, +) -> SurfSenseContextSchema: + return SurfSenseContextSchema( + search_space_id=search_space_id, + request_id=request_id, + turn_id=turn_id, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/__init__.py new file mode 100644 index 000000000..b65acc43c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/__init__.py @@ -0,0 +1,3 @@ +"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py new file mode 100644 index 000000000..be1f102f3 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -0,0 +1,107 @@ +"""Server-side assistant-message + token_usage finalization. + +Runs inside the streaming flow's ``finally`` block, after the main session has +been closed (uses its own shielded session, so we don't fight the same DB +connection). + +Idempotent against the legacy frontend ``appendMessage`` recovery branch: + + * the assistant row was already INSERTed by ``persist_assistant_shell`` + earlier in the turn, so this just UPDATEs it with the rich + ``ContentPart[]`` projection from the builder. + * ``token_usage`` uses ``INSERT ... 
ON CONFLICT DO NOTHING`` against the + partial unique index from migration 142, so a racing append_message + recovery branch can never double-write. + +``mark_interrupted`` closes any open text/reasoning blocks and flips running +tool-calls (no result) to ``state=aborted`` so the persisted JSONB reflects a +coherent end-state even on client disconnect. + +Never raises (best-effort, logs only). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.utils.perf import get_perf_logger + +if TYPE_CHECKING: + from app.services.token_tracking_service import TokenAccumulator + +_perf_log = get_perf_logger() + + +async def finalize_assistant_message( + *, + stream_result: StreamResult | None, + chat_id: int, + search_space_id: int, + user_id: str | None, + accumulator: TokenAccumulator, + log_prefix: str, +) -> None: + """Snapshot the content builder and persist the final assistant payload. + + No-op when ``stream_result`` was never populated, the turn never reached + ``persist_assistant_shell`` (no ``assistant_message_id``), or the turn id + was never assigned. + """ + if not ( + stream_result and stream_result.turn_id and stream_result.assistant_message_id + ): + return + + from app.tasks.chat.persistence import finalize_assistant_turn + + builder_stats: dict[str, int] | None = None + if stream_result.content_builder is not None: + stream_result.content_builder.mark_interrupted() + # Snapshot stats BEFORE ``snapshot()`` deepcopies so the perf log + # records the actual finalised payload (post-mark_interrupted), not + # the live-mutating builder state. + builder_stats = stream_result.content_builder.stats() + content_payload = stream_result.content_builder.snapshot() + else: + # Defensive fallback — we always set the builder alongside + # ``assistant_message_id`` in the orchestrator, so this branch only + # fires if a future refactor ever decouples them. 
Persist whatever + # accumulated text we captured so the row at least renders. + content_payload = [ + { + "type": "text", + "text": stream_result.accumulated_text or "", + } + ] + + if builder_stats is not None: + _perf_log.info( + "[%s] finalize_payload chat_id=%s " + "message_id=%s parts=%d bytes=%d text=%d " + "reasoning=%d tool_calls=%d " + "tool_calls_completed=%d tool_calls_aborted=%d " + "thinking_step_parts=%d step_separators=%d", + log_prefix, + chat_id, + stream_result.assistant_message_id, + builder_stats["parts"], + builder_stats["bytes"], + builder_stats["text"], + builder_stats["reasoning"], + builder_stats["tool_calls"], + builder_stats["tool_calls_completed"], + builder_stats["tool_calls_aborted"], + builder_stats["thinking_step_parts"], + builder_stats["step_separators"], + ) + + await finalize_assistant_turn( + message_id=stream_result.assistant_message_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + turn_id=stream_result.turn_id, + content=content_payload, + accumulator=accumulator, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/finalize_emit.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/finalize_emit.py new file mode 100644 index 000000000..e5de3f6a4 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/finalize_emit.py @@ -0,0 +1,54 @@ +"""Emit the per-turn token-usage SSE frame from the accumulator. + +``per_message_summary()`` returns ``None`` when the turn made no chargeable +LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the +frontend has no usage to render. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from app.services.new_streaming_service import VercelStreamingService +from app.utils.perf import get_perf_logger + +if TYPE_CHECKING: + from app.services.token_tracking_service import TokenAccumulator + +_perf_log = get_perf_logger() +logger = logging.getLogger(__name__) + + +def iter_token_usage_frame( + streaming_service: VercelStreamingService, + *, + accumulator: TokenAccumulator, + log_label: str, +): + """Yield zero or one ``data: token-usage`` SSE frame. + + Side effect: logs a one-line ``[token_usage] {log_label}: ...`` summary so + cost analysis can grep call/total/cost across all flows. + """ + usage_summary = accumulator.per_message_summary() + _perf_log.info( + "[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s", + log_label, + len(accumulator.calls), + accumulator.grand_total, + accumulator.total_cost_micros, + usage_summary, + ) + if usage_summary: + yield streaming_service.format_data( + "token-usage", + { + "usage": usage_summary, + "prompt_tokens": accumulator.total_prompt_tokens, + "completion_tokens": accumulator.total_completion_tokens, + "total_tokens": accumulator.grand_total, + "cost_micros": accumulator.total_cost_micros, + "call_details": accumulator.serialized_calls(), + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/finally_cleanup.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/finally_cleanup.py new file mode 100644 index 000000000..f9454775e --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/finally_cleanup.py @@ -0,0 +1,67 @@ +"""Shared finally-block helpers: session close, GC pass, native-heap trim. 
+ +These are called from inside an ``anyio.CancelScope(shield=True)`` block in +each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on +client disconnect; without the shield the very first ``await`` would raise +``CancelledError`` and the rest of cleanup — including ``session.close()`` — +would never run). +""" + +from __future__ import annotations + +import contextlib +import gc +import logging + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import shielded_async_session +from app.services.chat_session_state_service import clear_ai_responding +from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap + +_perf_log = get_perf_logger() +logger = logging.getLogger(__name__) + + +async def close_session_and_clear_ai_responding( + session: AsyncSession, chat_id: int +) -> None: + """Rollback + clear AI-responding flag + expunge_all + close. + + On rollback failure we fall back to a fresh shielded session for the flag + clear so a UI is never stuck on "AI is responding…" after a crash. + """ + try: + await session.rollback() + await clear_ai_responding(session, chat_id) + except Exception: + try: + async with shielded_async_session() as fresh_session: + await clear_ai_responding(fresh_session, chat_id) + except Exception: + logger.warning("Failed to clear AI responding state for thread %s", chat_id) + + with contextlib.suppress(Exception): + session.expunge_all() + + with contextlib.suppress(Exception): + await session.close() + + +def run_gc_pass(*, log_prefix: str, chat_id: int) -> None: + """One full gen0/1/2 pass + native-heap trim + END system snapshot. + + Breaking circular refs held by the agent graph, tools, and LLM wrappers + needs to happen in the caller (set the locals to ``None``) — this just + runs the collector and logs how many objects came back. 
+ """ + collected = gc.collect(0) + gc.collect(1) + gc.collect(2) + if collected: + _perf_log.info( + "[%s] gc.collect() reclaimed %d objects (chat_id=%s)", + log_prefix, + collected, + chat_id, + ) + trim_native_heap() + log_system_snapshot(f"{log_prefix}_END") diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/first_frames.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/first_frames.py new file mode 100644 index 000000000..5e568b1e8 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/first_frames.py @@ -0,0 +1,40 @@ +"""Initial SSE frames every flow emits right after pre-stream setup. + +Order matters: ``message_start`` opens the assistant message, ``start_step`` +opens the first thinking step, ``turn-info`` lets the frontend stamp the +correlation id onto the in-flight message, and ``turn-status: busy`` flips the +UI into the streaming state. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.services.new_streaming_service import VercelStreamingService + + +def iter_initial_frames( + streaming_service: VercelStreamingService, + *, + turn_id: str, +) -> Iterator[str]: + """Yield the four canonical opening frames in order. + + ``turn-info`` carries ``chat_turn_id`` so even pure-text turns (which + never produce a tool / action-log event) still teach the frontend the + turn correlation id used for ``appendMessage`` durable storage. 
+ """ + yield streaming_service.format_message_start() + yield streaming_service.format_start_step() + yield streaming_service.format_data("turn-info", {"chat_turn_id": turn_id}) + yield streaming_service.format_data("turn-status", {"status": "busy"}) + + +def iter_final_frames( + streaming_service: VercelStreamingService, +) -> Iterator[str]: + """Yield ``turn-status: idle`` plus the finish/done trailer in order.""" + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py new file mode 100644 index 000000000..2f334114c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -0,0 +1,57 @@ +"""Load an LLM + AgentConfig bundle for a given config id. + +Handles both code paths uniformly: +- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). +- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). + +Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is +``None``. The caller emits the friendly SSE error frame. 
+""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.llm_config import ( + AgentConfig, + create_chat_litellm_from_agent_config, + create_chat_litellm_from_config, + load_agent_config, + load_global_llm_config_by_id, +) + + +async def load_llm_bundle( + session: AsyncSession, + *, + config_id: int, + search_space_id: int, +) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py new file mode 100644 index 000000000..ec92306dd --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/pre_stream_setup.py @@ -0,0 +1,40 @@ +"""Pre-stream setup: connector service, firecrawl key, checkpointer.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.checkpointer import get_checkpointer +from app.db import SearchSourceConnectorType +from app.services.connector_service import ConnectorService + + +async def setup_connector_and_firecrawl( + session: AsyncSession, + *, + search_space_id: int, +) -> tuple[ConnectorService, str | None]: + """Build the per-turn connector service and pull the 
firecrawl API key. + + Returns ``(connector_service, firecrawl_api_key)``. ``firecrawl_api_key`` is + ``None`` when no web-crawler connector is configured (the agent simply + skips firecrawl-backed tools in that case). + """ + connector_service = ConnectorService(session, search_space_id=search_space_id) + firecrawl_api_key: str | None = None + webcrawler_connector = await connector_service.get_connector_by_type( + SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id + ) + if webcrawler_connector and webcrawler_connector.config: + firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") + return connector_service, firecrawl_api_key + + +async def get_chat_checkpointer(): + """Resolve the PostgreSQL checkpointer for persistent conversation memory. + + Thin wrapper around ``app.agents.new_chat.checkpointer.get_checkpointer`` so + flow orchestrators can rely on a streaming-local symbol and we have a hook + point if the checkpointer source ever needs to vary per flow. + """ + return await get_checkpointer() diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py new file mode 100644 index 000000000..cbf44764c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py @@ -0,0 +1,132 @@ +"""Premium credit (USD micro-units) reserve / finalize / release lifecycle. + +Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up +front (so a single LLM call can't run away with the budget), then finalize the +actual provider cost reported by LiteLLM when the turn completes successfully, +or release the reservation on the cancellation / interrupted-without-finalize +paths. + +State is held by the orchestrator as a simple ``PremiumReservation`` tuple +so reservation, fallback-on-denied, finalize, and release can all be reasoned +about from one place. 
+""" + +from __future__ import annotations + +import logging +import uuid as _uuid +from dataclasses import dataclass +from typing import TYPE_CHECKING +from uuid import UUID + +from app.agents.new_chat.llm_config import AgentConfig +from app.db import shielded_async_session + +if TYPE_CHECKING: + from app.services.token_tracking_service import TokenAccumulator + + +@dataclass +class PremiumReservation: + """Active premium-credit reservation for one turn. + + ``request_id`` is the per-reservation idempotency key (also passed to + ``finalize``/``release`` so racing branches resolve to the same row). + ``reserved_micros`` is the up-front estimate; ``finalize`` debits the + actual cost, ``release`` returns it untouched. + """ + + request_id: str + reserved_micros: int + allowed: bool + + +def needs_premium_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool: + return bool(agent_config is not None and user_id and agent_config.is_premium) + + +async def reserve_premium( + *, + agent_config: AgentConfig, + user_id: str, +) -> PremiumReservation: + """Reserve estimated micros up front; returns the reservation handle.""" + from app.services.token_quota_service import ( + TokenQuotaService, + estimate_call_reserve_micros, + ) + + request_id = _uuid.uuid4().hex[:16] + litellm_params = agent_config.litellm_params or {} + base_model = ( + (litellm_params.get("base_model") if isinstance(litellm_params, dict) else None) + or agent_config.model_name + or "" + ) + reserve_amount_micros = estimate_call_reserve_micros( + base_model=base_model, + quota_reserve_tokens=agent_config.quota_reserve_tokens, + ) + async with shielded_async_session() as quota_session: + quota_result = await TokenQuotaService.premium_reserve( + db_session=quota_session, + user_id=UUID(user_id), + request_id=request_id, + reserve_micros=reserve_amount_micros, + ) + return PremiumReservation( + request_id=request_id, + reserved_micros=reserve_amount_micros, + allowed=quota_result.allowed, + ) + 
+ +async def finalize_premium( + *, + reservation: PremiumReservation, + user_id: str, + accumulator: TokenAccumulator, +) -> None: + """Finalize debit using the actual provider cost reported by LiteLLM. + + Best-effort: failures here must not bubble up to the SSE stream — the user + has already received their tokens; we log and move on. + """ + try: + from app.services.token_quota_service import TokenQuotaService + + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_finalize( + db_session=quota_session, + user_id=UUID(user_id), + request_id=reservation.request_id, + actual_micros=accumulator.total_cost_micros, + reserved_micros=reservation.reserved_micros, + ) + except Exception: + logging.getLogger(__name__).warning( + "Failed to finalize premium quota for user %s", + user_id, + exc_info=True, + ) + + +async def release_premium( + *, + reservation: PremiumReservation, + user_id: str, +) -> None: + """Release the reservation on cancellation paths; never raises.""" + try: + from app.services.token_quota_service import TokenQuotaService + + async with shielded_async_session() as quota_session: + await TokenQuotaService.premium_release( + db_session=quota_session, + user_id=UUID(user_id), + reserved_micros=reservation.reserved_micros, + ) + except Exception: + logging.getLogger(__name__).warning( + "Failed to release premium quota for user %s", user_id + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py new file mode 100644 index 000000000..6b3857594 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py @@ -0,0 +1,129 @@ +"""Shared steps for the in-stream provider rate-limit recovery loop. 
+ +Both flows wrap ``run_stream_loop`` with a flow-specific ``recover`` closure; +the *guard*, the *auto-pin reroute*, and the *post-recovery telemetry* are the +same on both sides and live here so behaviour can't drift. + +The orchestrator owns the parts that genuinely diverge: + + * cancelling the title task (new_chat only), + * passing ``mentioned_document_ids`` to ``build_main_agent_for_thread``, + * the log prefix (``stream_new_chat`` vs ``stream_resume``). +""" + +from __future__ import annotations + +from typing import Literal + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.middleware.busy_mutex import end_turn +from app.observability import otel as ot +from app.services.auto_model_pin_service import ( + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) +from app.tasks.chat.streaming.errors.classifier import ( + is_provider_rate_limited, + log_chat_stream_error, +) + + +def can_recover_provider_rate_limit( + exc: BaseException, + *, + first_event_seen: bool, + runtime_rate_limit_recovered: bool, + requested_llm_config_id: int, + current_llm_config_id: int, +) -> bool: + """Guard: only the first auto-pin → provider-rate-limited failure recovers. + + All conditions must hold: + + * ``runtime_rate_limit_recovered is False`` — at most one recovery per turn. + * ``requested_llm_config_id == 0`` — caller opted into auto-pin (id=0). + * ``current_llm_config_id < 0`` — currently on a YAML config (the only + kind the auto-pin pool draws from). + * ``first_event_seen is False`` — we haven't sent any SSE to the user yet, + so a silent rebuild + retry is invisible. + * The exception is provider-side rate-limited (HTTP 429 or known shape). 
+ """ + return ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and current_llm_config_id < 0 + and not first_event_seen + and is_provider_rate_limited(exc) + ) + + +async def reroute_to_next_auto_pin( + session: AsyncSession, + *, + chat_id: int, + search_space_id: int, + user_id: str | None, + current_llm_config_id: int, + requires_image_input: bool, +) -> int: + """Release lock, cool down the failing config, pick a new auto-pin id. + + Returns the new ``llm_config_id``. ``end_turn`` is called because the failed + attempt may still hold the per-thread busy mutex (middleware teardown can + lag behind raised provider errors) — the same-request retry would otherwise + bounce on ``BusyError``. + """ + end_turn(str(chat_id)) + mark_runtime_cooldown(current_llm_config_id, reason="provider_rate_limited") + pinned = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={current_llm_config_id}, + requires_image_input=requires_image_input, + ) + return pinned.resolved_llm_config_id + + +def log_rate_limit_recovered( + *, + flow: Literal["new", "regenerate", "resume"], + request_id: str | None, + chat_id: int, + search_space_id: int, + user_id: str | None, + previous_config_id: int, + new_config_id: int, +) -> None: + """Emit the OTEL event + structured ``[chat_stream_error]`` log line.""" + ot.add_event( + "chat.rate_limit.recovered", + { + "recovery.reason": "provider_rate_limited", + "recovery.previous_config_id": previous_config_id, + "recovery.fallback_config_id": new_config_id, + }, + ) + log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and 
retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": new_config_id, + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/span.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/span.py new file mode 100644 index 000000000..74b9682ed --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/span.py @@ -0,0 +1,79 @@ +"""OpenTelemetry chat-request span wrapper for streaming flows.""" + +from __future__ import annotations + +import contextlib +import sys +from typing import Any, Literal + +from app.observability import metrics as ot_metrics, otel as ot + + +def open_chat_request_span( + *, + chat_id: int, + search_space_id: int, + flow: Literal["new", "regenerate", "resume"], + request_id: str | None, + turn_id: str, + filesystem_mode: str, + client_platform: str, + agent_mode: str, +) -> tuple[Any, Any]: + """Open the per-request span; returns ``(span_cm, span)`` for finally-close.""" + span_cm = ot.chat_request_span( + chat_id=chat_id, + search_space_id=search_space_id, + flow=flow, + request_id=request_id, + turn_id=turn_id, + filesystem_mode=filesystem_mode, + client_platform=client_platform, + agent_mode=agent_mode, + ) + span = span_cm.__enter__() + return span_cm, span + + +def set_agent_mode(span: Any, agent_mode: str) -> None: + """Tag the span with the resolved agent mode (single / multi).""" + with contextlib.suppress(Exception): + span.set_attribute("agent.mode", agent_mode) + + +def close_chat_request_span( + *, + span_cm: Any, + span: Any, + chat_outcome: str, + chat_agent_mode: str, + flow: Literal["new", "regenerate", "resume"], + chat_error_category: str | None, + duration_seconds: float, +) -> None: + """Record metrics + close the span. 
Swallows errors (finally-block context).""" + with contextlib.suppress(Exception): + span.set_attribute("chat.outcome", chat_outcome) + ot_metrics.record_chat_request_duration( + duration_seconds * 1000, + flow=flow, + outcome=chat_outcome, + agent_mode=chat_agent_mode, + ) + ot_metrics.record_chat_request_outcome( + flow=flow, + outcome=chat_outcome, + agent_mode=chat_agent_mode, + error_category=chat_error_category, + ) + span_cm.__exit__(*sys.exc_info()) + + +def record_outcome_attrs( + span: Any, *, chat_outcome: str, chat_error_category: str | None +) -> None: + """Stamp outcome + error.category on the span (used in the except branch).""" + with contextlib.suppress(Exception): + span.set_attribute("chat.outcome", chat_outcome) + if chat_error_category is not None: + span.set_attribute("error.category", chat_error_category) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py new file mode 100644 index 000000000..6cf0df855 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py @@ -0,0 +1,85 @@ +"""Drive ``stream_agent_events`` with in-stream rate-limit recovery. + +Both ``stream_new_chat`` and ``stream_resume_chat`` wrap the agent event loop +in a ``while True`` that catches the *first* provider rate-limit error +(``can_runtime_recover``) before any SSE event reaches the user, rebuilds the +agent on an alternative auto-pin, and retries the turn. + +The recovery callback is flow-specific (different ``mentioned_document_ids`` +contract, different logging label, etc.) — this module owns the loop shape, +the caller owns the rebuild. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.streaming.agent.event_loop import stream_agent_events +from app.tasks.chat.streaming.shared.stream_result import StreamResult + +# Returns the rebuilt agent on a successful recovery, or ``None`` to re-raise +# the original exception (and let the orchestrator's terminal-error path +# handle it). +RecoverFn = Callable[[BaseException, bool], Awaitable[Any | None]] + + +async def run_stream_loop( + *, + agent: Any, + streaming_service: VercelStreamingService, + config: dict[str, Any], + input_data: Any, + stream_result: StreamResult, + step_prefix: str = "thinking", + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + fallback_commit_search_space_id: int | None, + fallback_commit_created_by_id: str | None, + fallback_commit_filesystem_mode: FilesystemMode, + fallback_commit_thread_id: int | None, + runtime_context: Any, + content_builder: Any | None, + recover: RecoverFn, + on_first_event: Callable[[], None] | None = None, +) -> AsyncGenerator[str, None]: + """Yield SSE frames; rebuild and retry once on a pre-first-event rate limit. + + ``on_first_event`` fires after the first frame is observed (used by both + flows to write a one-time ``First agent event in N.NNNs`` perf line). 
+ """ + first_event_logged = False + while True: + try: + async for sse in stream_agent_events( + agent=agent, + config=config, + input_data=input_data, + streaming_service=streaming_service, + result=stream_result, + step_prefix=step_prefix, + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + fallback_commit_search_space_id=fallback_commit_search_space_id, + fallback_commit_created_by_id=fallback_commit_created_by_id, + fallback_commit_filesystem_mode=fallback_commit_filesystem_mode, + fallback_commit_thread_id=fallback_commit_thread_id, + runtime_context=runtime_context, + content_builder=content_builder, + ): + if not first_event_logged: + if on_first_event is not None: + on_first_event() + first_event_logged = True + yield sse + return + except Exception as exc: + new_agent = await recover(exc, first_event_logged) + if new_agent is None: + raise + agent = new_agent + continue diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py new file mode 100644 index 000000000..b305dba23 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/terminal_error.py @@ -0,0 +1,119 @@ +"""Handle the ``except Exception`` branch of a streaming flow. + +Classifies the exception, records OpenTelemetry attributes, emits one terminal +error SSE frame and the trailing ``turn-status: idle`` + finish/done frames. + +Used by both ``stream_new_chat`` and ``stream_resume_chat``; flow-specific bits +(label, span, BusyError tracking) are passed by the caller. 
+""" + +from __future__ import annotations + +import logging +import traceback +from collections.abc import Iterator +from typing import Any, Literal + +from app.agents.new_chat.errors import BusyError +from app.observability import metrics as ot_metrics, otel as ot +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.streaming.errors.classifier import classify_stream_exception +from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error +from app.tasks.chat.streaming.flows.shared.first_frames import iter_final_frames +from app.tasks.chat.streaming.flows.shared.span import record_outcome_attrs + +logger = logging.getLogger(__name__) + + +def handle_terminal_exception( + exc: Exception, + *, + flow: Literal["new", "regenerate", "resume"], + flow_label: str, + log_prefix: str, + streaming_service: VercelStreamingService, + request_id: str | None, + chat_id: int, + search_space_id: int, + user_id: str | None, + chat_span: Any, +) -> tuple[Iterator[str], dict[str, Any]]: + """Classify, log, and produce the SSE frames for a terminal exception. + + Returns ``(frame_iterator, summary)``. ``summary`` carries:: + + - ``busy_error_raised``: bool — caller must skip the lock-release path + when True (caller never acquired the busy mutex). + - ``chat_outcome``: str — span outcome attribute. + - ``chat_error_category``: str — categorized error label for metrics. 
+ """ + busy_error_raised = isinstance(exc, BusyError) + + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + error_extra, + ) = classify_stream_exception(exc, flow_label=flow_label) + chat_outcome = error_code or error_kind or "error" + chat_error_category = ot_metrics.categorize_exception(exc) + record_outcome_attrs( + chat_span, + chat_outcome=chat_outcome, + chat_error_category=chat_error_category, + ) + with __suppress(): + ot.record_error(chat_span, exc) + error_message = f"Error during {flow_label}: {exc!s}" + # Match the original behavior: log full traceback via ``print`` so it lands + # in stderr regardless of the logger config. + print(f"[{log_prefix}] {error_message}") + print(f"[{log_prefix}] Exception type: {type(exc).__name__}") + print(f"[{log_prefix}] Traceback:\n{traceback.format_exc()}") + + def _iter_frames() -> Iterator[str]: + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) + + yield emit_stream_terminal_error( + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + extra=error_extra, + ) + yield from iter_final_frames(streaming_service) + + return ( + _iter_frames(), + { + "busy_error_raised": busy_error_raised, + "chat_outcome": chat_outcome, + "chat_error_category": chat_error_category, + }, + ) + + +def __suppress(): + """Local single-use ``contextlib.suppress(Exception)`` factory. + + Inlined here so callers don't import ``contextlib`` just for the + ``record_error`` call site. 
+ """ + import contextlib + + return contextlib.suppress(Exception) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py index ad4a17d08..2ff810447 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py @@ -6,6 +6,9 @@ import json from collections.abc import Iterator from typing import Any +from langchain_core.messages import ToolMessage +from langgraph.types import Command + from app.tasks.chat.streaming.handlers.tools import ( ToolCompletionEmissionContext, iter_tool_completion_emission_frames, @@ -19,6 +22,38 @@ from app.tasks.chat.streaming.relay.task_span import ( from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame +def _unwrap_command_output(raw_output: Any) -> Any: + """Replace a ``Command`` from a tool return with its inner ``ToolMessage``. + + Tools that participate in receipt-style state writes (see + ``app.agents.shared.receipt_command.with_receipt``) return a + ``Command(update={"messages": [ToolMessage(...)], "receipts": [...]})``. + LangChain's ``on_tool_end`` event surfaces that ``Command`` verbatim as + ``data.output``, which the rest of this handler can't introspect: it has + no ``.content``, isn't a ``dict``, and stringifies to ``"Command(...)"``. + That stringified payload reaches the frontend and breaks tool-specific + UI components (e.g. the podcast card) that look for ``status`` / + ``podcast_id`` at the top level. + + We extract the first ``ToolMessage`` from the Command's ``messages`` list + so downstream code can read ``.content`` normally. Commands that don't + contain a ``ToolMessage`` (rare, e.g. pure state updates) are returned + unchanged — the existing ``str(raw_output)`` fallback handles them. 
+ """ + if not isinstance(raw_output, Command): + return raw_output + update = raw_output.update + if not isinstance(update, dict): + return raw_output + messages = update.get("messages") + if not isinstance(messages, list): + return raw_output + for msg in messages: + if isinstance(msg, ToolMessage): + return msg + return raw_output + + def iter_tool_end_frames( event: dict[str, Any], *, @@ -33,7 +68,7 @@ def iter_tool_end_frames( state.active_tool_depth = max(0, state.active_tool_depth - 1) run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") - raw_output = event.get("data", {}).get("output", "") + raw_output = _unwrap_command_output(event.get("data", {}).get("output", "")) staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None if hasattr(raw_output, "content"): diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py index 21e27d4c3..51a67f369 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py @@ -15,12 +15,24 @@ def iter_completion_emission_frames( out = ctx.tool_output payload = out if isinstance(out, dict) else {"result": out} yield ctx.emit_tool_output_card(payload) - if isinstance(out, dict) and out.get("status") == "pending": + if not isinstance(out, dict): + return + status = out.get("status") + # ``ready`` is the live success status now that the tool waits for the + # Celery worker to reach a terminal state. ``pending`` is retained as a + # legacy branch for old saved chats that pre-date the wait-for-terminal + # change (see ``app.agents.shared.deliverable_wait``). 
+ if status == "ready": + yield ctx.streaming_service.format_terminal_info( + f"Video presentation generated successfully: {out.get('title', 'Presentation')}", + "success", + ) + elif status == "pending": yield ctx.streaming_service.format_terminal_info( f"Video presentation queued: {out.get('title', 'Presentation')}", "success", ) - elif isinstance(out, dict) and out.get("status") == "failed": + elif status == "failed": error_msg = out.get("error", "Unknown error") yield ctx.streaming_service.format_terminal_info( f"Presentation generation failed: {error_msg}", diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/shared/__init__.py new file mode 100644 index 000000000..6c9f1f6b5 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/shared/__init__.py @@ -0,0 +1,15 @@ +"""Shared building blocks used across every streaming flow.""" + +from __future__ import annotations + +from app.tasks.chat.streaming.shared.stream_result import StreamResult +from app.tasks.chat.streaming.shared.utils import ( + resume_step_prefix, + safe_float, +) + +__all__ = [ + "StreamResult", + "resume_step_prefix", + "safe_float", +] diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py new file mode 100644 index 000000000..a940e8a9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -0,0 +1,37 @@ +"""Per-turn streaming state shared between the orchestrator and event loop.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StreamResult: + accumulated_text: str = "" + is_interrupted: bool = False + sandbox_files: list[str] = field(default_factory=list) + request_id: str | None = None + turn_id: str = "" + filesystem_mode: str = "cloud" + client_platform: str = "web" + intent_detected: str = "chat_only" + 
intent_confidence: float = 0.0 + write_attempted: bool = False + write_succeeded: bool = False + verification_succeeded: bool = False + commit_gate_passed: bool = True + commit_gate_reason: str = "" + # Pre-allocated assistant ``new_chat_messages.id`` for this turn, captured by + # ``persist_assistant_shell`` right after the user row is persisted. ``None`` + # for the legacy/anonymous code paths that don't opt in to server-side + # ``ContentPart[]`` projection. + assistant_message_id: int | None = None + # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, populated + # by the lifecycle methods called from the streaming event loop at each + # ``streaming_service.format_*`` yield site. Snapshot in the streaming + # ``finally`` to produce the rich JSONB persisted by + # ``finalize_assistant_turn``. ``repr=False`` keeps the log-on-error path + # (``StreamResult`` is logged in some error branches) from dumping a + # potentially-large parts list. + content_builder: Any | None = field(default=None, repr=False) diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/utils.py b/surfsense_backend/app/tasks/chat/streaming/shared/utils.py new file mode 100644 index 000000000..fe6901543 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/shared/utils.py @@ -0,0 +1,27 @@ +"""Small utilities used by streaming orchestrators and phases.""" + +from __future__ import annotations + +from typing import Any + + +def resume_step_prefix(turn_id: str) -> str: + """Per-turn ``step_prefix`` for resume invocations. + + Each ``stream_agent_events`` call constructs a fresh + ``AgentEventRelayState`` with ``thinking_step_counter=0``, so two consecutive + resume turns would otherwise both emit ``thinking-resume-1``, ``-2`` etc. + The frontend rehydrates ``currentThinkingSteps`` from the immediate prior + assistant message at the start of every resume — if the new stream's IDs + collide with the seeded ones, React renders sibling Timeline rows with the + same key. 
Salting with ``turn_id`` guarantees disjoint IDs across resumes + within one thread. + """ + return f"thinking-resume-{turn_id}" + + +def safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default diff --git a/surfsense_backend/app/tasks/surfsense_docs_indexer.py b/surfsense_backend/app/tasks/surfsense_docs_indexer.py deleted file mode 100644 index db88c8700..000000000 --- a/surfsense_backend/app/tasks/surfsense_docs_indexer.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -Surfsense documentation indexer. -Indexes MDX documentation files at startup. -""" - -import hashlib -import logging -import re -from datetime import UTC, datetime -from pathlib import Path - -from sqlalchemy import delete as sa_delete, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from sqlalchemy.orm.attributes import set_committed_value - -from app.config import config -from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker -from app.utils.document_converters import embed_text - -logger = logging.getLogger(__name__) - - -async def _safe_set_docs_chunks( - session: AsyncSession, document: SurfsenseDocsDocument, chunks: list -) -> None: - """safe_set_chunks variant for the SurfsenseDocsDocument/Chunk models.""" - if document.id is not None: - await session.execute( - sa_delete(SurfsenseDocsChunk).where( - SurfsenseDocsChunk.document_id == document.id - ) - ) - for chunk in chunks: - chunk.document_id = document.id - - set_committed_value(document, "chunks", chunks) - session.add_all(chunks) - - -# Path to docs relative to project root -DOCS_DIR = ( - Path(__file__).resolve().parent.parent.parent.parent - / "surfsense_web" - / "content" - / "docs" -) - - -def parse_mdx_frontmatter(content: str) -> tuple[str, str]: - """ - Parse MDX file to extract frontmatter title and content. 
- - Args: - content: Raw MDX file content - - Returns: - Tuple of (title, content_without_frontmatter) - """ - # Match frontmatter between --- markers - frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n" - match = re.match(frontmatter_pattern, content, re.DOTALL) - - if match: - frontmatter = match.group(1) - content_without_frontmatter = content[match.end() :] - - # Extract title from frontmatter - title_match = re.search(r"^title:\s*(.+)$", frontmatter, re.MULTILINE) - title = title_match.group(1).strip() if title_match else "Untitled" - - # Remove quotes if present - title = title.strip("\"'") - - return title, content_without_frontmatter.strip() - - return "Untitled", content.strip() - - -def get_all_mdx_files() -> list[Path]: - """ - Get all MDX files from the docs directory. - - Returns: - List of Path objects for each MDX file - """ - if not DOCS_DIR.exists(): - logger.warning(f"Docs directory not found: {DOCS_DIR}") - return [] - - return list(DOCS_DIR.rglob("*.mdx")) - - -def generate_surfsense_docs_content_hash(content: str) -> str: - """Generate SHA-256 hash for Surfsense docs content.""" - return hashlib.sha256(content.encode("utf-8")).hexdigest() - - -def create_surfsense_docs_chunks(content: str) -> list[SurfsenseDocsChunk]: - """ - Create chunks from Surfsense documentation content. - - Args: - content: Document content to chunk - - Returns: - List of SurfsenseDocsChunk objects with embeddings - """ - return [ - SurfsenseDocsChunk( - content=chunk.text, - embedding=embed_text(chunk.text), - ) - for chunk in config.chunker_instance.chunk(content) - ] - - -async def index_surfsense_docs(session: AsyncSession) -> tuple[int, int, int, int]: - """ - Index all Surfsense documentation files. 
- - Args: - session: SQLAlchemy async session - - Returns: - Tuple of (created, updated, skipped, deleted) counts - """ - created = 0 - updated = 0 - skipped = 0 - deleted = 0 - - # Get all existing docs from database - existing_docs_result = await session.execute( - select(SurfsenseDocsDocument).options( - selectinload(SurfsenseDocsDocument.chunks) - ) - ) - existing_docs = {doc.source: doc for doc in existing_docs_result.scalars().all()} - - # Track which sources we've processed - processed_sources = set() - - # Get all MDX files - mdx_files = get_all_mdx_files() - logger.info(f"Found {len(mdx_files)} MDX files to index") - - for mdx_file in mdx_files: - try: - source = str(mdx_file.relative_to(DOCS_DIR)) - processed_sources.add(source) - - # Read file content - raw_content = mdx_file.read_text(encoding="utf-8") - title, content = parse_mdx_frontmatter(raw_content) - content_hash = generate_surfsense_docs_content_hash(raw_content) - - if source in existing_docs: - existing_doc = existing_docs[source] - - # Check if content changed - if existing_doc.content_hash == content_hash: - logger.debug(f"Skipping unchanged: {source}") - skipped += 1 - continue - - # Content changed - update document - logger.info(f"Updating changed document: {source}") - - # Create new chunks - chunks = create_surfsense_docs_chunks(content) - - # Update document fields - existing_doc.title = title - existing_doc.content = content - existing_doc.content_hash = content_hash - existing_doc.embedding = embed_text(content) - await _safe_set_docs_chunks(session, existing_doc, chunks) - existing_doc.updated_at = datetime.now(UTC) - - updated += 1 - else: - # New document - create it - logger.info(f"Creating new document: {source}") - - chunks = create_surfsense_docs_chunks(content) - - document = SurfsenseDocsDocument( - source=source, - title=title, - content=content, - content_hash=content_hash, - embedding=embed_text(content), - chunks=chunks, - updated_at=datetime.now(UTC), - ) - - 
session.add(document) - created += 1 - - except Exception as e: - logger.error(f"Error processing {mdx_file}: {e}", exc_info=True) - continue - - # Delete documents for removed files - for source, doc in existing_docs.items(): - if source not in processed_sources: - logger.info(f"Deleting removed document: {source}") - await session.delete(doc) - deleted += 1 - - # Commit all changes - await session.commit() - - logger.info( - f"Indexing complete: {created} created, {updated} updated, " - f"{skipped} skipped, {deleted} deleted" - ) - - return created, updated, skipped, deleted - - -async def seed_surfsense_docs() -> tuple[int, int, int, int]: - """ - Seed Surfsense documentation into the database. - - This function indexes all MDX files from the docs directory. - It handles creating, updating, and deleting docs based on content changes. - - Returns: - Tuple of (created, updated, skipped, deleted) counts - Returns (0, 0, 0, 0) if an error occurs - """ - logger.info("Starting Surfsense docs indexing...") - - try: - async with async_session_maker() as session: - created, updated, skipped, deleted = await index_surfsense_docs(session) - - logger.info( - f"Surfsense docs indexing complete: " - f"created={created}, updated={updated}, skipped={skipped}, deleted={deleted}" - ) - - return created, updated, skipped, deleted - - except Exception as e: - logger.error(f"Failed to seed Surfsense docs: {e}", exc_info=True) - return 0, 0, 0, 0 diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index 9bc8103c5..059d91806 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -222,9 +222,7 @@ async def generate_document_summary( else: enhanced_summary_content = summary_content - summary_embedding = await asyncio.to_thread( - embed_text, enhanced_summary_content - ) + summary_embedding = await asyncio.to_thread(embed_text, enhanced_summary_content) 
return enhanced_summary_content, summary_embedding diff --git a/surfsense_backend/app/utils/perf.py b/surfsense_backend/app/utils/perf.py index b2b26897c..541ee5756 100644 --- a/surfsense_backend/app/utils/perf.py +++ b/surfsense_backend/app/utils/perf.py @@ -16,6 +16,8 @@ import time from contextlib import asynccontextmanager, contextmanager from typing import Any +from app.observability import metrics as ot_metrics + _perf_log: logging.Logger | None = None _last_rss_mb: float = 0.0 @@ -50,6 +52,7 @@ def perf_timer(label: str, *, extra: dict[str, Any] | None = None): if extra: suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items()) log.info("%s in %.3fs%s", label, elapsed, suffix) + ot_metrics.record_perf_elapsed(elapsed * 1000, label=label) @asynccontextmanager @@ -68,6 +71,7 @@ async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None): if extra: suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items()) log.info("%s in %.3fs%s", label, elapsed, suffix) + ot_metrics.record_perf_elapsed(elapsed * 1000, label=label) def system_snapshot() -> dict[str, Any]: diff --git a/surfsense_backend/app/utils/surfsense_docs.py b/surfsense_backend/app/utils/surfsense_docs.py deleted file mode 100644 index 9a6ab11a9..000000000 --- a/surfsense_backend/app/utils/surfsense_docs.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Utilities for SurfSense's built-in documentation index.""" - -from pathlib import PurePosixPath - -DOCS_PUBLIC_ROOT = PurePosixPath("/docs") - - -def surfsense_docs_public_url(source: str) -> str: - """Return the public docs route for an indexed documentation source path.""" - docs_path = PurePosixPath(source).with_suffix("") - if docs_path.name == "index": - docs_path = docs_path.parent - return (DOCS_PUBLIC_ROOT / docs_path).as_posix() diff --git a/surfsense_backend/main.py b/surfsense_backend/main.py index 4a7a9b7b1..54911a34d 100644 --- a/surfsense_backend/main.py +++ b/surfsense_backend/main.py @@ -12,9 +12,26 @@ if sys.platform == 
"win32": from app.config.uvicorn import load_uvicorn_config +_old_log_record_factory = logging.getLogRecordFactory() + + +def _otel_safe_log_record_factory(*args, **kwargs): + record = _old_log_record_factory(*args, **kwargs) + if not hasattr(record, "otelTraceID"): + record.otelTraceID = "0" + if not hasattr(record, "otelSpanID"): + record.otelSpanID = "0" + return record + + +logging.setLogRecordFactory(_otel_safe_log_record_factory) + logging.basicConfig( level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + format=( + "%(asctime)s - %(name)s - %(levelname)s - " + "[trace_id=%(otelTraceID)s span_id=%(otelSpanID)s] %(message)s" + ), datefmt="%Y-%m-%d %H:%M:%S", ) diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index cd2a6921a..51405ec74 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.25" +version = "0.0.26" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ @@ -76,6 +76,18 @@ dependencies = [ "litellm>=1.83.7", "langchain-litellm>=0.6.4", "deepagents>=0.4.12,<0.5", + "opentelemetry-api>=1.40.0", + "opentelemetry-sdk>=1.40.0", + "opentelemetry-exporter-otlp>=1.40.0", + "opentelemetry-semantic-conventions>=0.61b0", + "opentelemetry-instrumentation-fastapi>=0.61b0", + "opentelemetry-instrumentation-sqlalchemy>=0.61b0", + "opentelemetry-instrumentation-psycopg>=0.61b0", + "opentelemetry-instrumentation-redis>=0.61b0", + "opentelemetry-instrumentation-httpx>=0.61b0", + "opentelemetry-instrumentation-celery>=0.61b0", + "opentelemetry-instrumentation-logging>=0.61b0", + "croniter>=2.0.0", ] [dependency-groups] diff --git a/surfsense_backend/scripts/seed_surfsense_docs.py b/surfsense_backend/scripts/seed_surfsense_docs.py deleted file mode 100644 index 68899c2aa..000000000 --- a/surfsense_backend/scripts/seed_surfsense_docs.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env 
python -""" -Seed Surfsense documentation into the database. - -CLI wrapper for the seed_surfsense_docs function. -Can be run manually for debugging or re-indexing. - -Usage: - python scripts/seed_surfsense_docs.py -""" - -import asyncio -import sys -from pathlib import Path - -# Add the parent directory to the path so we can import app modules -sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) - -from app.tasks.surfsense_docs_indexer import seed_surfsense_docs - - -def main(): - """CLI entry point for seeding Surfsense docs.""" - print("=" * 50) - print(" Surfsense Documentation Seeding") - print("=" * 50) - - created, updated, skipped, deleted = asyncio.run(seed_surfsense_docs()) - - print() - print("Results:") - print(f" Created: {created}") - print(f" Updated: {updated}") - print(f" Skipped: {skipped}") - print(f" Deleted: {deleted}") - print("=" * 50) - - -if __name__ == "__main__": - main() diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py index ac6b5d95c..2f222e148 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py @@ -60,7 +60,6 @@ class TestReadOnlyToolsAllowed: "glob", "web_search", "scrape_webpage", - "search_surfsense_docs", "get_connected_accounts", "write_todos", "task", diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py index 55434c04d..dc59c6dac 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -23,6 +23,7 @@ pytestmark = pytest.mark.unit @pytest.fixture(autouse=True) def _disable_otel(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + 
monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") from app.observability import otel as ot @@ -99,16 +100,17 @@ class TestAnnotateModelResponse: "total_tokens": 150, }, ) - _annotate_model_response(sp, msg) - sp.set_attribute.assert_any_call("tokens.prompt", 100) - sp.set_attribute.assert_any_call("tokens.completion", 50) - sp.set_attribute.assert_any_call("tokens.total", 150) + assert _annotate_model_response(sp, msg) == (100, 50) + sp.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 100) + sp.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 50) + sp.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + sp.set_attribute.assert_any_call("gen_ai.operation.name", "chat") def test_handles_response_with_no_metadata(self) -> None: sp = MagicMock() msg = AIMessage(content="hello") # Should not raise even when usage_metadata is missing - _annotate_model_response(sp, msg) + assert _annotate_model_response(sp, msg) == (None, None) class TestAnnotateToolResult: @@ -119,7 +121,7 @@ class TestAnnotateToolResult: tool_call_id="abc", status="success", ) - _annotate_tool_result(sp, result) + assert _annotate_tool_result(sp, result) is False sp.set_attribute.assert_any_call("tool.output.size", len("result text")) sp.set_attribute.assert_any_call("tool.status", "success") @@ -130,7 +132,7 @@ class TestAnnotateToolResult: tool_call_id="abc", additional_kwargs={"error": {"code": "x"}}, ) - _annotate_tool_result(sp, result) + assert _annotate_tool_result(sp, result) is True sp.set_attribute.assert_any_call("tool.error", True) @@ -193,3 +195,91 @@ class TestMiddlewareIntegration: assert result.content == "enabled" finally: ot.reload_for_tests() + + async def test_enabled_model_call_records_metrics( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from 
app.observability import otel as ot + + duration_calls: list[dict[str, Any]] = [] + token_calls: list[dict[str, Any]] = [] + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_call_duration", + lambda duration_ms, **attrs: duration_calls.append( + {"duration_ms": duration_ms, **attrs} + ), + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_token_usage", + lambda **attrs: token_calls.append(attrs), + ) + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return AIMessage( + content="enabled", + usage_metadata={ + "input_tokens": 3, + "output_tokens": 5, + "total_tokens": 8, + }, + ) + + request = MagicMock() + request.model = MagicMock() + request.model.model_name = "gpt-4o" + request.model.provider = "openai" + await mw.awrap_model_call(request, handler) + + assert duration_calls + assert token_calls == [ + { + "input_tokens": 3, + "output_tokens": 5, + "model": "gpt-4o", + "provider": "openai", + } + ] + finally: + ot.reload_for_tests() + + async def test_enabled_tool_call_records_error_metric( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from app.observability import otel as ot + + errors: list[str] = [] + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_error", + lambda *, tool_name: errors.append(tool_name), + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_duration", + lambda *args, **kwargs: None, + ) + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return ToolMessage( + content="failed", + tool_call_id="abc", + status="error", + ) + + request = MagicMock() + request.tool = MagicMock() + request.tool.name = "web_search" + await mw.awrap_tool_call(request, 
handler) + assert errors == ["web_search"] + finally: + ot.reload_for_tests() diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py index 3035cc8e0..3c7fe5336 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py @@ -22,12 +22,6 @@ from app.agents.new_chat.subagents.config import ( # --------------------------------------------------------------------------- -@tool -def search_surfsense_docs(query: str) -> str: - """Search the user's KB.""" - return "" - - @tool def web_search(query: str) -> str: """Search the public web.""" @@ -95,7 +89,6 @@ def generate_report(topic: str) -> str: ALL_TOOLS = [ - search_surfsense_docs, web_search, scrape_webpage, read_file, @@ -161,7 +154,7 @@ class TestReportWriterSubagent: names = {t.name for t in spec["tools"]} # type: ignore[index] assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS} assert "generate_report" in names - assert "search_surfsense_docs" in names + assert "read_file" in names def test_deny_rules_block_writes_but_allow_generate_report(self) -> None: spec = build_report_writer_subagent(tools=ALL_TOOLS) @@ -272,9 +265,9 @@ class TestFilterToolsWarningSuppression: # Allowed set asks for two registry tools (one present, one # not) plus a bunch of middleware-provided names. 
_filter_tools( - [search_surfsense_docs], + [web_search], allowed_names={ - "search_surfsense_docs", + "web_search", "scrape_webpage", # legitimately missing → should warn "read_file", # mw-provided → suppressed "ls", @@ -322,7 +315,6 @@ class TestDenyPatternsCoverage: def test_deny_patterns_do_not_match_safe_read_tools(self) -> None: canonical_reads = [ - "search_surfsense_docs", "read_file", "ls_tree", "grep", diff --git a/surfsense_backend/tests/unit/automations/__init__.py b/surfsense_backend/tests/unit/automations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/actions/__init__.py b/surfsense_backend/tests/unit/automations/actions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/__init__.py b/surfsense_backend/tests/unit/automations/actions/builtin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/__init__.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_auto_decide.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_auto_decide.py new file mode 100644 index 000000000..d8f45eadf --- /dev/null +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_auto_decide.py @@ -0,0 +1,73 @@ +"""Lock ``build_auto_decisions`` — the HITL auto-approve/reject wire mapper. + +``build_auto_decisions`` walks ``state.interrupts`` (duck-typed) and produces +two parallel resume maps: one keyed by LangGraph ``Interrupt.id`` and one +keyed by ``tool_call_id`` for the subagent middleware bridge. Both carry +the same decision payload. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from app.automations.actions.builtin.agent_task.auto_decide import build_auto_decisions + +pytestmark = pytest.mark.unit + + +def _state(interrupts: list[Any]) -> SimpleNamespace: + """Build a duck-typed LangGraph state stub carrying ``interrupts``.""" + return SimpleNamespace(interrupts=interrupts) + + +def _interrupt(*, id_: str, value: Any) -> SimpleNamespace: + """Build a duck-typed interrupt with the canonical ``(id, value)`` shape.""" + return SimpleNamespace(id=id_, value=value) + + +def test_build_auto_decisions_produces_one_decision_per_action_request() -> None: + """An interrupt carrying N ``action_requests`` produces N decisions of + the requested type in both maps. This is the canonical batched-HITL + wire shape — losing a decision would leave a pending action stuck.""" + interrupt = _interrupt( + id_="lg-1", + value={ + "tool_call_id": "tc-1", + "action_requests": [{"id": "a"}, {"id": "b"}], + }, + ) + + lg_map, routed = build_auto_decisions(_state([interrupt]), "approve") + + assert lg_map == {"lg-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}} + assert routed == {"tc-1": {"decisions": [{"type": "approve"}, {"type": "approve"}]}} + + +def test_build_auto_decisions_defaults_to_one_decision_for_scalar_interrupt() -> None: + """When an interrupt's value has no ``action_requests`` list, the + function defaults to a single decision. 
Locks compatibility with + older single-action interrupt shapes still emitted by some tools.""" + interrupt = _interrupt(id_="lg-2", value={"tool_call_id": "tc-2"}) + + lg_map, routed = build_auto_decisions(_state([interrupt]), "reject") + + assert lg_map == {"lg-2": {"decisions": [{"type": "reject"}]}} + assert routed == {"tc-2": {"decisions": [{"type": "reject"}]}} + + +def test_build_auto_decisions_skips_interrupts_with_invalid_shape() -> None: + """Interrupts missing the canonical ``(str id, dict value)`` shape are + skipped silently rather than crashing the resume loop. Locks the + resilience contract — a malformed interrupt from a misbehaving tool + shouldn't take down the whole agent_task step.""" + good = _interrupt(id_="lg-good", value={"tool_call_id": "tc-good"}) + bad_value = _interrupt(id_="lg-bad-value", value="not a dict") + bad_id = _interrupt(id_=None, value={"tool_call_id": "tc-bad-id"}) # type: ignore[arg-type] + + lg_map, routed = build_auto_decisions(_state([good, bad_value, bad_id]), "approve") + + assert lg_map == {"lg-good": {"decisions": [{"type": "approve"}]}} + assert routed == {"tc-good": {"decisions": [{"type": "approve"}]}} diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py new file mode 100644 index 000000000..ac20b2608 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -0,0 +1,174 @@ +"""Lock the runtime model-policy backstop in ``build_dependencies``. + +Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so +runs are insulated from later chat/search-space model changes), and the model +policy is re-checked at run time so a captured model that is no longer billable +fails the run clearly. When no snapshot is present, resolution falls back to the +live search space. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +import app.automations.actions.agent_task.dependencies as deps_mod +from app.automations.actions.agent_task.dependencies import ( + DependencyError, + build_dependencies, +) +from app.automations.services.model_policy import AutomationModelPolicyError + +pytestmark = pytest.mark.unit + + +class _FakeSession: + """Minimal async session whose ``get`` returns a preset search space.""" + + def __init__(self, search_space: Any) -> None: + self._search_space = search_space + + async def get(self, _model: Any, _pk: int) -> Any: + return self._search_space + + +@pytest.fixture +def patched_side_effects(monkeypatch: pytest.MonkeyPatch): + """Stub the connector setup + checkpointer so only policy/LLM logic runs.""" + + async def _fake_setup(_session, *, search_space_id): + return (SimpleNamespace(name="connector"), "fc-key") + + monkeypatch.setattr(deps_mod, "setup_connector_and_firecrawl", _fake_setup) + return None + + +async def test_build_dependencies_resolves_captured_agent_llm_id( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" + captured: dict[str, Any] = {} + + async def _fake_load(_session, *, config_id, search_space_id): + captured["config_id"] = config_id + captured["search_space_id"] = search_space_id + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + # Captured path validates the explicit ids; passes for this test. + monkeypatch.setattr(deps_mod, "assert_models_billable", lambda **_kw: None) + # A different value on the live search space proves we ignore it when a + # snapshot is supplied. 
+ monkeypatch.setattr( + deps_mod, + "assert_automation_models_billable", + lambda _ss: pytest.fail("search-space policy should not run on captured path"), + ) + + search_space = SimpleNamespace(agent_llm_id=-99) + result = await build_dependencies( + session=_FakeSession(search_space), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + + assert captured == {"config_id": -7, "search_space_id": 42} + assert result.llm.name == "llm" + assert result.firecrawl_api_key == "fc-key" + + +async def test_build_dependencies_validates_captured_ids( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """The captured ids (not the search space) are what gets policy-checked.""" + seen: dict[str, Any] = {} + + def _capture(**kwargs): + seen.update(kwargs) + + monkeypatch.setattr(deps_mod, "assert_models_billable", _capture) + + async def _fake_load(_session, *, config_id, search_space_id): + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + + await build_dependencies( + session=_FakeSession(SimpleNamespace(agent_llm_id=0)), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + + assert seen == { + "agent_llm_id": -7, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + +async def test_build_dependencies_raises_on_captured_policy_violation( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """A blocked captured model raises ``DependencyError`` so the step fails clearly.""" + + def _raise(**_kw): + raise AutomationModelPolicyError( + [{"kind": "image", "config_id": -2, "reason": "free model"}] + ) + + monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) + monkeypatch.setattr( + deps_mod, + "load_llm_bundle", + lambda *a, **k: pytest.fail("load_llm_bundle should not be called"), + ) + + with 
pytest.raises(DependencyError): + await build_dependencies( + session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=-2, + vision_llm_config_id=-1, + ) + + +async def test_build_dependencies_falls_back_to_search_space( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """With no captured snapshot, resolve + validate the live search space.""" + captured: dict[str, Any] = {} + + async def _fake_load(_session, *, config_id, search_space_id): + captured["config_id"] = config_id + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + monkeypatch.setattr(deps_mod, "assert_automation_models_billable", lambda _ss: None) + monkeypatch.setattr( + deps_mod, + "assert_models_billable", + lambda **_kw: pytest.fail("captured policy should not run on fallback path"), + ) + + search_space = SimpleNamespace(agent_llm_id=-7) + result = await build_dependencies( + session=_FakeSession(search_space), search_space_id=42 + ) + + assert captured == {"config_id": -7} + assert result.llm.name == "llm" + + +async def test_build_dependencies_raises_when_search_space_missing( + patched_side_effects, +) -> None: + """A missing search space (fallback path) surfaces as a ``DependencyError``.""" + with pytest.raises(DependencyError): + await build_dependencies(session=_FakeSession(None), search_space_id=999) diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_finalize.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_finalize.py new file mode 100644 index 000000000..9e2143438 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_finalize.py @@ -0,0 +1,86 @@ +"""Lock ``extract_final_assistant_message`` — what surfaces in run output. + +Each scenario is one shape the agent runtime is observed to produce. 
+Locking these means we can refactor the extractor without losing +backwards compatibility with already-stored ``run.output`` payloads. +""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from app.automations.actions.builtin.agent_task.finalize import ( + extract_final_assistant_message, +) + +pytestmark = pytest.mark.unit + + +def test_extract_returns_last_ai_message_string_content() -> None: + """The canonical shape: the agent's final ``AIMessage`` carries a + plain string. That string is returned verbatim, trimmed.""" + result = { + "messages": [ + HumanMessage(content="ask"), + AIMessage(content="the answer"), + ] + } + + assert extract_final_assistant_message(result) == "the answer" + + +def test_extract_concatenates_text_parts_and_skips_non_text_parts() -> None: + """Multi-part AIMessage content (Anthropic / OpenAI list shape) joins + its ``text`` parts in order; non-text parts (tool_use, images, ...) + are skipped. Locks the wire shape used when the model emits tool + calls alongside narrative text in the same turn.""" + result = { + "messages": [ + AIMessage( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "tool_use", "name": "search", "input": {}}, + {"type": "text", "text": "world"}, + ] + ) + ] + } + + assert extract_final_assistant_message(result) == "Hello world" + + +def test_extract_returns_last_ai_message_skipping_tool_messages() -> None: + """When the transcript ends with tool calls and tool results, the + extractor still walks back to the **last** ``AIMessage`` (the agent's + final narrative answer). 
Locks resilience against trailing + ``ToolMessage`` payloads in the transcript.""" + result = { + "messages": [ + HumanMessage(content="ask"), + AIMessage(content="thinking..."), + ToolMessage(content="tool output", tool_call_id="tc-1"), + AIMessage(content="final answer"), + ToolMessage(content="trailing tool noise", tool_call_id="tc-2"), + ] + } + + assert extract_final_assistant_message(result) == "final answer" + + +def test_extract_returns_none_when_no_assistant_text_is_present() -> None: + """No ``AIMessage`` with extractable text → ``None`` rather than the + empty string. Lets callers branch on "did the agent actually say + anything?" rather than guess whether ``""`` means silence or empty + output. Empty-string contents are normalized to ``None`` too.""" + no_ai = {"messages": [HumanMessage(content="just a question")]} + only_tools = { + "messages": [ + AIMessage(content=[{"type": "tool_use", "name": "x", "input": {}}]) + ] + } + empty_string = {"messages": [AIMessage(content=" ")]} + + assert extract_final_assistant_message(no_ai) is None + assert extract_final_assistant_message(only_tools) is None + assert extract_final_assistant_message(empty_string) is None diff --git a/surfsense_backend/tests/unit/automations/conftest.py b/surfsense_backend/tests/unit/automations/conftest.py new file mode 100644 index 000000000..0fbf03234 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/conftest.py @@ -0,0 +1,39 @@ +"""Shared fixtures for the ``app.automations`` unit-test tree. + +Provides registry isolation: the built-in ``schedule`` trigger and +``agent_task`` action self-register at import time. Tests that register +additional triggers/actions (or assert on the registry contents) must +not leak that state to other tests. These fixtures snapshot and restore +the module-level registry dicts. 
+""" + +from __future__ import annotations + +from collections.abc import Iterator + +import pytest + +from app.automations.actions import store as action_store +from app.automations.triggers import store as trigger_store + + +@pytest.fixture +def isolated_action_registry() -> Iterator[None]: + """Snapshot and restore the action registry around a test.""" + snapshot = dict(action_store._REGISTRY) + try: + yield + finally: + action_store._REGISTRY.clear() + action_store._REGISTRY.update(snapshot) + + +@pytest.fixture +def isolated_trigger_registry() -> Iterator[None]: + """Snapshot and restore the trigger registry around a test.""" + snapshot = dict(trigger_store._REGISTRY) + try: + yield + finally: + trigger_store._REGISTRY.clear() + trigger_store._REGISTRY.update(snapshot) diff --git a/surfsense_backend/tests/unit/automations/dispatch/__init__.py b/surfsense_backend/tests/unit/automations/dispatch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/dispatch/test_errors.py b/surfsense_backend/tests/unit/automations/dispatch/test_errors.py new file mode 100644 index 000000000..89c1bede9 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/dispatch/test_errors.py @@ -0,0 +1,28 @@ +"""Lock the ``DispatchError`` exception contract. + +``DispatchError`` is the uniform exception type the dispatch layer raises +for any "cannot turn this fire request into a run" condition. Other +modules (templates of error envelopes, run records) compare on +``isinstance(exc, DispatchError)``, so the inheritance is the contract. 
+""" + +from __future__ import annotations + +import pytest + +from app.automations.dispatch.errors import DispatchError + +pytestmark = pytest.mark.unit + + +def test_dispatch_error_is_exception_subclass_and_carries_message() -> None: + """Lifting a string into ``DispatchError`` preserves the message and + behaves as a regular ``Exception`` for ``isinstance`` / ``raise`` / + ``except`` consumers.""" + error = DispatchError("missing trigger") + + assert isinstance(error, Exception) + assert str(error) == "missing trigger" + + with pytest.raises(DispatchError): + raise error diff --git a/surfsense_backend/tests/unit/automations/dispatch/test_inputs.py b/surfsense_backend/tests/unit/automations/dispatch/test_inputs.py new file mode 100644 index 000000000..2744982a0 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/dispatch/test_inputs.py @@ -0,0 +1,74 @@ +"""Lock the input-validation contract enforced before a run is enqueued. + +``validate_inputs`` is the pure schema check that ``enqueue_run`` runs against +merged inputs. ``enqueue_run`` itself needs a real DB session, so tests target +this pure function directly; the contract — not the symbol — is what's locked. 
+""" + +from __future__ import annotations + +import pytest + +from app.automations.dispatch.errors import DispatchError +from app.automations.dispatch.inputs import validate_inputs +from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.schemas.definition.inputs import Inputs +from app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +def _minimal_definition(*, inputs: Inputs | None = None) -> AutomationDefinition: + """One-step definition with an optional declared input schema.""" + return AutomationDefinition( + name="test", + inputs=inputs, + plan=[PlanStep(step_id="s1", action="agent_task")], + ) + + +def test_validate_inputs_passes_through_when_no_schema_is_declared() -> None: + """When the definition declares no input schema, runtime inputs reach + the template context **unchanged**. Regression site: previously this + branch returned ``{}``, which stripped runtime keys like ``fired_at`` + and ``last_fired_at`` and made Jinja blow up on ``{{ inputs.* }}``. 
+ """ + definition = _minimal_definition(inputs=None) + runtime_inputs = { + "fired_at": "2026-01-01T00:00:00+00:00", + "last_fired_at": None, + "static_key": "value", + } + + assert validate_inputs(definition, runtime_inputs) == runtime_inputs + + +def test_validate_inputs_returns_inputs_when_they_match_declared_schema() -> None: + """With a declared JSON schema, inputs that satisfy it pass through + unchanged (validation succeeds; the function does not coerce or + strip extra fields not mentioned in the schema).""" + schema = { + "type": "object", + "properties": {"topic": {"type": "string"}}, + "required": ["topic"], + } + definition = _minimal_definition(inputs=Inputs(schema=schema)) + + inputs = {"topic": "weekly report"} + + assert validate_inputs(definition, inputs) == inputs + + +def test_validate_inputs_raises_dispatch_error_when_inputs_violate_schema() -> None: + """Inputs that don't match the declared schema must surface as + ``DispatchError`` (not the raw ``jsonschema.ValidationError``), so every + caller can handle one dispatch-domain exception type uniformly.""" + schema = { + "type": "object", + "properties": {"topic": {"type": "string"}}, + "required": ["topic"], + } + definition = _minimal_definition(inputs=Inputs(schema=schema)) + + with pytest.raises(DispatchError): + validate_inputs(definition, {"topic": 42}) # type violates string diff --git a/surfsense_backend/tests/unit/automations/runtime/__init__.py b/surfsense_backend/tests/unit/automations/runtime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/runtime/test_execute_step.py b/surfsense_backend/tests/unit/automations/runtime/test_execute_step.py new file mode 100644 index 000000000..9b203fdba --- /dev/null +++ b/surfsense_backend/tests/unit/automations/runtime/test_execute_step.py @@ -0,0 +1,272 @@ +"""Lock the ``execute_step`` orchestration contract. 
+ +Covers the pure step-execution logic: predicate gate, params rendering, +action lookup, retry budget, error shaping. The ``ActionContext.session`` +is never touched by ``execute_step`` itself (it's only forwarded to the +handler), so unit tests pass ``None`` cast to the type. +""" + +from __future__ import annotations + +from typing import Any, cast + +import pytest +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.actions.store import register_action +from app.automations.actions.types import ActionContext, ActionDefinition +from app.automations.runtime.step import execute_step +from app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +class _AnyParams(BaseModel): + """Open params model used by test actions — they never validate.""" + + model_config = {"extra": "allow"} + + +def _action_context() -> ActionContext: + """Minimal context: session is unused by ``execute_step``, only forwarded.""" + return ActionContext( + session=cast(AsyncSession, None), + run_id=1, + step_id="s1", + search_space_id=1, + creator_user_id=None, + ) + + +async def test_execute_step_runs_registered_action_handler_and_wraps_result( + isolated_action_registry: None, +) -> None: + """A step pointing at a registered action runs its handler with the + step's params and returns a ``succeeded`` entry carrying the handler's + output plus ``attempts=1`` (one try, no retries triggered).""" + invocations: list[dict[str, Any]] = [] + + async def echo(params: dict[str, Any]) -> dict[str, Any]: + invocations.append(params) + return {"echoed": params["value"]} + + register_action( + ActionDefinition( + type="test_echo", + name="Echo", + description="Test action.", + params_model=_AnyParams, + build_handler=lambda _ctx: echo, + ) + ) + + step = PlanStep(step_id="s1", action="test_echo", params={"value": "hello"}) + + result = await execute_step( + step=step, + template_context={}, + 
action_context=_action_context(), + default_max_retries=0, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert result["status"] == "succeeded" + assert result["step_id"] == "s1" + assert result["action"] == "test_echo" + assert result["attempts"] == 1 + assert result["result"] == {"echoed": "hello"} + assert invocations == [{"value": "hello"}] + + +async def test_execute_step_skips_step_when_predicate_is_falsy( + isolated_action_registry: None, +) -> None: + """If ``step.when`` evaluates to falsy in the template context, the + handler is **not** invoked, the result entry has ``status=skipped`` + and ``attempts=0``, and no ``result`` key is present.""" + invoked = False + + async def must_not_run(_params: dict[str, Any]) -> dict[str, Any]: + nonlocal invoked + invoked = True + return {} + + register_action( + ActionDefinition( + type="test_guarded", + name="Guarded", + description="Test action that should not run.", + params_model=_AnyParams, + build_handler=lambda _ctx: must_not_run, + ) + ) + + step = PlanStep( + step_id="s1", + action="test_guarded", + when="inputs.enabled", + params={}, + ) + + result = await execute_step( + step=step, + template_context={"inputs": {"enabled": False}}, + action_context=_action_context(), + default_max_retries=0, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert result["status"] == "skipped" + assert result["attempts"] == 0 + assert "result" not in result + assert invoked is False + + +async def test_execute_step_fails_when_step_references_an_unknown_action( + isolated_action_registry: None, +) -> None: + """A step pointing at an action that isn't in the registry must fail + with ``ActionNotFound`` rather than crashing. 
Catches typos in the + plan and removed actions without the run going off the rails.""" + step = PlanStep(step_id="s1", action="no_such_action", params={}) + + result = await execute_step( + step=step, + template_context={}, + action_context=_action_context(), + default_max_retries=0, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert result["status"] == "failed" + assert result["attempts"] == 0 + assert result["error"]["type"] == "ActionNotFound" + assert "no_such_action" in result["error"]["message"] + + +async def test_execute_step_retries_failing_handler_up_to_default_budget( + isolated_action_registry: None, +) -> None: + """A handler that raises on every attempt consumes the retry budget + (1 initial try + ``default_max_retries`` retries) and the step ends + ``failed`` with the exception's type and message surfaced through + the error envelope.""" + calls = 0 + + async def always_fails(_params: dict[str, Any]) -> dict[str, Any]: + nonlocal calls + calls += 1 + raise RuntimeError("boom") + + register_action( + ActionDefinition( + type="test_fails", + name="Fails", + description="Always raises.", + params_model=_AnyParams, + build_handler=lambda _ctx: always_fails, + ) + ) + + step = PlanStep(step_id="s1", action="test_fails", params={}) + + result = await execute_step( + step=step, + template_context={}, + action_context=_action_context(), + default_max_retries=2, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert result["status"] == "failed" + assert result["attempts"] == 3 + assert calls == 3 + assert result["error"]["type"] == "RuntimeError" + assert "boom" in result["error"]["message"] + + +async def test_execute_step_succeeds_when_handler_recovers_within_retry_budget( + isolated_action_registry: None, +) -> None: + """A handler that fails the first N times and then succeeds yields a + ``succeeded`` entry with ``attempts == N + 1``. 
Locks that retries + can actually recover (not just exhaust).""" + calls = 0 + + async def flaky(_params: dict[str, Any]) -> dict[str, Any]: + nonlocal calls + calls += 1 + if calls < 3: + raise RuntimeError("transient") + return {"ok": True} + + register_action( + ActionDefinition( + type="test_flaky", + name="Flaky", + description="Fails twice, succeeds third time.", + params_model=_AnyParams, + build_handler=lambda _ctx: flaky, + ) + ) + + step = PlanStep(step_id="s1", action="test_flaky", params={}) + + result = await execute_step( + step=step, + template_context={}, + action_context=_action_context(), + default_max_retries=2, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert result["status"] == "succeeded" + assert result["attempts"] == 3 + assert result["result"] == {"ok": True} + assert calls == 3 + + +async def test_execute_step_renders_step_params_through_template_engine( + isolated_action_registry: None, +) -> None: + """Step params are rendered against the template context before the + handler is invoked. 
String values containing Jinja expressions get + substituted from ``inputs`` and ``steps`` in the run context.""" + received: list[dict[str, Any]] = [] + + async def capture(params: dict[str, Any]) -> dict[str, Any]: + received.append(params) + return {} + + register_action( + ActionDefinition( + type="test_capture", + name="Capture", + description="Captures the params passed in.", + params_model=_AnyParams, + build_handler=lambda _ctx: capture, + ) + ) + + step = PlanStep( + step_id="s1", + action="test_capture", + params={"message": "Hello {{ inputs.name }}"}, + ) + + await execute_step( + step=step, + template_context={"inputs": {"name": "World"}, "steps": {}}, + action_context=_action_context(), + default_max_retries=0, + default_retry_backoff="none", + default_timeout_seconds=30, + ) + + assert received == [{"message": "Hello World"}] diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py new file mode 100644 index 000000000..d7e3c4a0c --- /dev/null +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -0,0 +1,59 @@ +"""Lock that the executor propagates the captured model snapshot into the +``ActionContext``, so runs resolve their own model (insulated from chat / +search-space changes) and not the live search space. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import cast + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.runtime.executor import _build_action_ctx +from app.automations.schemas.definition.envelope import AutomationModels +from app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +def _run() -> SimpleNamespace: + return SimpleNamespace( + id=1, + automation=SimpleNamespace(search_space_id=42, created_by_user_id="u-1"), + ) + + +def test_build_action_ctx_propagates_captured_models() -> None: + """``definition.models`` flows onto the ActionContext model fields.""" + models = AutomationModels( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + ctx = _build_action_ctx( + cast(AsyncSession, None), + _run(), + PlanStep(step_id="s1", action="agent_task"), + models, + ) + + assert ctx.search_space_id == 42 + assert ctx.agent_llm_id == -1 + assert ctx.image_generation_config_id == 5 + assert ctx.vision_llm_config_id == -1 + + +def test_build_action_ctx_none_models_leaves_fields_none() -> None: + """No captured snapshot → model fields are None (defensive fallback path).""" + ctx = _build_action_ctx( + cast(AsyncSession, None), + _run(), + PlanStep(step_id="s1", action="agent_task"), + None, + ) + + assert ctx.agent_llm_id is None + assert ctx.image_generation_config_id is None + assert ctx.vision_llm_config_id is None diff --git a/surfsense_backend/tests/unit/automations/runtime/test_retries.py b/surfsense_backend/tests/unit/automations/runtime/test_retries.py new file mode 100644 index 000000000..05fd02ab6 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/runtime/test_retries.py @@ -0,0 +1,74 @@ +"""Lock the ``with_retries`` policy: budget, recovery, exhaustion, timeout, backoff. + +Tests with ``backoff="none"`` to keep wall-clock time zero. 
Backoff sleep +values themselves are observed by monkeypatching ``asyncio.sleep`` so we +don't introduce flakiness via real timing. +""" + +from __future__ import annotations + +import pytest + +from app.automations.runtime.retries import with_retries + +pytestmark = pytest.mark.unit + + +async def test_with_retries_returns_result_and_attempts_one_on_first_success() -> None: + """A coroutine that succeeds on the first call returns its result + paired with ``attempts=1`` — no retry consumed.""" + calls = 0 + + async def succeed() -> str: + nonlocal calls + calls += 1 + return "ok" + + result, attempts = await with_retries( + succeed, max_retries=2, backoff="none", timeout=None + ) + + assert result == "ok" + assert attempts == 1 + assert calls == 1 + + +async def test_with_retries_returns_attempt_count_when_succeeding_after_failures() -> ( + None +): + """A coroutine that fails twice then succeeds returns ``attempts=3`` + (the actual attempt that produced the result). Locks the contract + that the caller can distinguish first-try success from a recovery.""" + calls = 0 + + async def flaky() -> str: + nonlocal calls + calls += 1 + if calls < 3: + raise RuntimeError("transient") + return "ok" + + result, attempts = await with_retries( + flaky, max_retries=5, backoff="none", timeout=None + ) + + assert result == "ok" + assert attempts == 3 + assert calls == 3 + + +async def test_with_retries_reraises_after_exhausting_the_budget() -> None: + """When the coroutine raises on every attempt within + ``1 + max_retries`` tries, the last exception propagates and the + handler is called exactly ``1 + max_retries`` times.""" + calls = 0 + + async def always_fails() -> str: + nonlocal calls + calls += 1 + raise RuntimeError(f"boom-{calls}") + + with pytest.raises(RuntimeError, match="boom-3"): + await with_retries(always_fails, max_retries=2, backoff="none", timeout=None) + + assert calls == 3 # 1 initial + 2 retries diff --git 
a/surfsense_backend/tests/unit/automations/schemas/__init__.py b/surfsense_backend/tests/unit/automations/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/schemas/api/__init__.py b/surfsense_backend/tests/unit/automations/schemas/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/schemas/api/test_api_automation.py b/surfsense_backend/tests/unit/automations/schemas/api/test_api_automation.py new file mode 100644 index 000000000..6ae3ce794 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/api/test_api_automation.py @@ -0,0 +1,82 @@ +"""Lock the request-side automation API schemas — the public validation gate.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.api.automation import AutomationCreate, AutomationUpdate + +pytestmark = pytest.mark.unit + + +_VALID_DEFINITION = { + "name": "Test", + "plan": [{"step_id": "s1", "action": "agent_task"}], +} + + +def test_automation_create_accepts_valid_minimal_payload() -> None: + """Happy path: just search_space_id, name, and a valid definition. + Triggers default to ``[]`` so users can attach them later.""" + payload = AutomationCreate.model_validate( + { + "search_space_id": 1, + "name": "Daily digest", + "definition": _VALID_DEFINITION, + } + ) + + assert payload.name == "Daily digest" + assert payload.description is None + assert payload.triggers == [] + + +def test_automation_create_cascades_validation_into_nested_definition() -> None: + """A bad ``definition`` (e.g. empty plan) fails at the API boundary, + not at the DB layer. 
Locks the cascade so corrupt definitions can't + sneak through a misshapen wire payload.""" + with pytest.raises(ValidationError): + AutomationCreate.model_validate( + { + "search_space_id": 1, + "name": "Bad", + "definition": {"name": "X", "plan": []}, # empty plan + } + ) + + +def test_automation_create_rejects_unknown_top_level_field() -> None: + """``extra='forbid'`` catches typos in API payloads at the boundary.""" + with pytest.raises(ValidationError): + AutomationCreate.model_validate( + { + "search_space_id": 1, + "name": "X", + "definition": _VALID_DEFINITION, + "owner": "tg", # not allowed + } + ) + + +def test_automation_create_rejects_empty_name() -> None: + """Name is required and constrained to 1..200 chars.""" + with pytest.raises(ValidationError): + AutomationCreate.model_validate( + { + "search_space_id": 1, + "name": "", + "definition": _VALID_DEFINITION, + } + ) + + +def test_automation_update_accepts_partial_payload_with_no_fields() -> None: + """All fields on ``AutomationUpdate`` are optional. 
An empty body is + a valid no-op update (the service layer decides what to do with it).""" + update = AutomationUpdate.model_validate({}) + + assert update.name is None + assert update.description is None + assert update.status is None + assert update.definition is None diff --git a/surfsense_backend/tests/unit/automations/schemas/api/test_api_trigger.py b/surfsense_backend/tests/unit/automations/schemas/api/test_api_trigger.py new file mode 100644 index 000000000..cabfc41af --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/api/test_api_trigger.py @@ -0,0 +1,47 @@ +"""Lock the request-side trigger API schemas.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.persistence.enums.trigger_type import TriggerType +from app.automations.schemas.api.trigger import TriggerCreate, TriggerUpdate + +pytestmark = pytest.mark.unit + + +def test_trigger_create_uses_safe_defaults_for_optional_fields() -> None: + """Defaults: empty ``params`` and ``static_inputs``, ``enabled=True``. 
+ These let callers create a trigger with just ``type`` + the params + the trigger requires.""" + trigger = TriggerCreate(type=TriggerType.SCHEDULE) # type: ignore[arg-type] + + assert trigger.type is TriggerType.SCHEDULE + assert trigger.params == {} + assert trigger.static_inputs == {} + assert trigger.enabled is True + + +def test_trigger_create_rejects_unknown_trigger_type_string() -> None: + """``type`` is a ``TriggerType`` enum, so any string outside the + enum's known values fails validation at the boundary.""" + with pytest.raises(ValidationError): + TriggerCreate.model_validate({"type": "webhook"}) # not in TriggerType + + +def test_trigger_create_rejects_unknown_field() -> None: + """``extra='forbid'`` catches typos in trigger payloads.""" + with pytest.raises(ValidationError): + TriggerCreate.model_validate( + {"type": "schedule", "param": {}} # typo: param vs params + ) + + +def test_trigger_update_accepts_partial_payload_with_no_fields() -> None: + """``TriggerUpdate`` is fully optional — empty body is valid (no-op).""" + update = TriggerUpdate() + + assert update.enabled is None + assert update.params is None + assert update.static_inputs is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/__init__.py b/surfsense_backend/tests/unit/automations/schemas/definition/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py new file mode 100644 index 000000000..25e193ffa --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -0,0 +1,90 @@ +"""Lock the ``AutomationDefinition`` envelope contract.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.definition.envelope import ( + AutomationDefinition, + AutomationModels, +) +from 
app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +def test_automation_definition_accepts_minimal_valid_input_with_sensible_defaults() -> ( + None +): + """A definition with just ``name`` + a one-step ``plan`` is valid and + fills in the rest with safe defaults so users don't have to write + out every section to get started.""" + definition = AutomationDefinition( + name="Daily digest", + plan=[PlanStep(step_id="s1", action="agent_task")], + ) + + assert definition.name == "Daily digest" + assert definition.schema_version == "1.0" + assert definition.goal is None + assert definition.inputs is None + assert definition.triggers == [] + # ``models`` is optional (populated server-side at create()). + assert definition.models is None + + +def test_automation_definition_models_round_trip() -> None: + """The captured ``models`` snapshot survives a model_dump/validate round-trip.""" + definition = AutomationDefinition( + name="Daily digest", + plan=[PlanStep(step_id="s1", action="agent_task")], + models=AutomationModels( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ), + ) + + dumped = definition.model_dump(mode="json", by_alias=True) + assert dumped["models"] == { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + restored = AutomationDefinition.model_validate(dumped) + assert restored.models is not None + assert restored.models.agent_llm_id == -1 + assert restored.models.image_generation_config_id == 5 + assert restored.models.vision_llm_config_id == -1 + + +def test_automation_definition_rejects_unknown_top_level_field() -> None: + """``extra='forbid'`` catches typos at validation time (e.g. 
``pln`` + instead of ``plan``) before the bad definition reaches storage.""" + with pytest.raises(ValidationError): + AutomationDefinition.model_validate( + { + "name": "X", + "plan": [{"step_id": "s1", "action": "agent_task"}], + "extra_field": "unexpected", + } + ) + + +def test_automation_definition_rejects_empty_plan() -> None: + """An automation with no plan steps has nothing to execute and must + be rejected at validation time.""" + with pytest.raises(ValidationError): + AutomationDefinition(name="X", plan=[]) + + +def test_automation_definition_rejects_empty_name() -> None: + """Name is required and must be non-empty so list views and audit + logs have something meaningful to display.""" + with pytest.raises(ValidationError): + AutomationDefinition( + name="", + plan=[PlanStep(step_id="s1", action="agent_task")], + ) diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_execution.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_execution.py new file mode 100644 index 000000000..15adefab0 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_execution.py @@ -0,0 +1,49 @@ +"""Lock the ``Execution`` defaults + literal-constraint contract. + +These defaults control production behavior of every automation that +doesn't override them; the defaults *are* the contract. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.definition.execution import Execution + +pytestmark = pytest.mark.unit + + +def test_execution_uses_production_defaults_when_no_overrides_provided() -> None: + """The defaults shipped to prod: 10-minute wall clock, 2 retries + per step, exponential backoff, drop overlapping runs. 
Changing any + of these is a behavioral release-note change.""" + execution = Execution() + + assert execution.timeout_seconds == 600 + assert execution.max_retries == 2 + assert execution.retry_backoff == "exponential" + assert execution.concurrency == "drop_if_running" + assert execution.on_failure == [] + + +def test_execution_rejects_unknown_retry_backoff_strategy() -> None: + """``retry_backoff`` is constrained to a closed set — typos like + ``"expontential"`` must fail validation, not silently coerce.""" + with pytest.raises(ValidationError): + Execution(retry_backoff="expontential") # type: ignore[arg-type] + + +def test_execution_rejects_unknown_concurrency_strategy() -> None: + """Same closed-set constraint on ``concurrency``.""" + with pytest.raises(ValidationError): + Execution(concurrency="parallel") # type: ignore[arg-type] + + +def test_execution_rejects_invalid_numeric_bounds() -> None: + """``timeout_seconds > 0`` and ``max_retries >= 0``. Zero or negative + values would produce nonsensical run behavior.""" + with pytest.raises(ValidationError): + Execution(timeout_seconds=0) + with pytest.raises(ValidationError): + Execution(max_retries=-1) diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_inputs.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_inputs.py new file mode 100644 index 000000000..5dc24463f --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_inputs.py @@ -0,0 +1,39 @@ +"""Lock the ``Inputs`` JSON ``schema``-alias roundtrip. + +The field is ``schema_`` in Python (``schema`` shadows a Pydantic builtin) +but is wire-named ``schema``. Locking the roundtrip means JSON definitions +authored anywhere (UI raw editor, NL drafter, CLI export) speak the same +wire shape. 
+""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.definition.inputs import Inputs + +pytestmark = pytest.mark.unit + + +def test_inputs_parses_wire_field_named_schema_into_schema_attribute() -> None: + """JSON payloads use ``schema`` (the convention). The model stores it + on the Python attribute ``schema_`` without shadowing the builtin.""" + parsed = Inputs.model_validate({"schema": {"type": "object"}}) + + assert parsed.schema_ == {"type": "object"} + + +def test_inputs_serializes_schema_attribute_back_to_wire_field_named_schema() -> None: + """Round-trip: serializing emits ``schema`` (alias), not ``schema_``. + Locks the consumer-visible JSON shape regardless of the Python name.""" + inputs = Inputs(schema={"type": "object"}) # type: ignore[call-arg] + + assert inputs.model_dump() == {"schema": {"type": "object"}} + + +def test_inputs_rejects_unknown_field() -> None: + """``extra='forbid'`` catches typos like ``shema`` so bad definitions + don't silently lose their input declaration.""" + with pytest.raises(ValidationError): + Inputs.model_validate({"schema": {}, "extra": "x"}) diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_metadata.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_metadata.py new file mode 100644 index 000000000..9ac90bb3f --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_metadata.py @@ -0,0 +1,37 @@ +"""Lock the ``Metadata`` ``extra='allow'`` contract — the only schema +that does. Free-form annotations on definitions (e.g. ``owner``, +``project``, ``created_by_ai``) need to round-trip through the envelope +without being rejected. 
+""" + +from __future__ import annotations + +import pytest + +from app.automations.schemas.definition.metadata import Metadata + +pytestmark = pytest.mark.unit + + +def test_metadata_preserves_unknown_keys() -> None: + """Unlike every other definition sub-schema, ``Metadata`` allows + extra keys and round-trips them — that's its purpose.""" + metadata = Metadata.model_validate( + { + "tags": ["weekly", "report"], + "owner": "tg", + "created_by_ai": True, + } + ) + + dumped = metadata.model_dump() + + assert dumped["tags"] == ["weekly", "report"] + assert dumped["owner"] == "tg" + assert dumped["created_by_ai"] is True + + +def test_metadata_defaults_tags_to_empty_list() -> None: + """No tags is the common case; the default is the empty list so + callers can append without a None check.""" + assert Metadata().tags == [] diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_plan_step.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_plan_step.py new file mode 100644 index 000000000..6896a7f5a --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_plan_step.py @@ -0,0 +1,52 @@ +"""Lock the ``PlanStep`` validation contract.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +def test_plan_step_accepts_minimal_input_with_safe_defaults() -> None: + """A step with just ``step_id`` + ``action`` is valid. 
Defaults + (no when, empty params, no output_as override, no retry/timeout + override) let the run inherit automation-wide defaults.""" + step = PlanStep(step_id="s1", action="agent_task") + + assert step.step_id == "s1" + assert step.action == "agent_task" + assert step.when is None + assert step.params == {} + assert step.output_as is None + assert step.max_retries is None + assert step.timeout_seconds is None + + +def test_plan_step_rejects_empty_step_id_and_action() -> None: + """``step_id`` and ``action`` are addressing primitives — empty + strings would silently break runtime lookups.""" + with pytest.raises(ValidationError): + PlanStep(step_id="", action="agent_task") + with pytest.raises(ValidationError): + PlanStep(step_id="s1", action="") + + +def test_plan_step_rejects_negative_max_retries_and_non_positive_timeout() -> None: + """Numeric constraints: ``max_retries >= 0`` and ``timeout_seconds > 0``. + Negative budgets or zero timeouts produce nonsensical run behavior.""" + with pytest.raises(ValidationError): + PlanStep(step_id="s1", action="agent_task", max_retries=-1) + with pytest.raises(ValidationError): + PlanStep(step_id="s1", action="agent_task", timeout_seconds=0) + + +def test_plan_step_rejects_unknown_field() -> None: + """``extra='forbid'`` catches typos like ``actoin`` (instead of + ``action``) before the bad step reaches storage.""" + with pytest.raises(ValidationError): + PlanStep.model_validate( + {"step_id": "s1", "action": "agent_task", "actoin": "agent_task"} + ) diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_trigger_spec.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_trigger_spec.py new file mode 100644 index 000000000..cf1a52466 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_trigger_spec.py @@ -0,0 +1,33 @@ +"""Lock the ``TriggerSpec`` validation contract — the entry shape used +inside an automation's ``triggers[]`` array. 
+""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.schemas.definition.trigger_spec import TriggerSpec + +pytestmark = pytest.mark.unit + + +def test_trigger_spec_accepts_type_with_default_empty_params() -> None: + """``type`` is required; ``params`` defaults to ``{}`` so triggers + that take no params don't need an explicit body.""" + spec = TriggerSpec(type="schedule") + + assert spec.type == "schedule" + assert spec.params == {} + + +def test_trigger_spec_rejects_empty_type() -> None: + """``type`` is the registry lookup key — empty would silently miss.""" + with pytest.raises(ValidationError): + TriggerSpec(type="") + + +def test_trigger_spec_rejects_unknown_field() -> None: + """``extra='forbid'`` catches typos at definition-validation time.""" + with pytest.raises(ValidationError): + TriggerSpec.model_validate({"type": "schedule", "paramz": {}}) diff --git a/surfsense_backend/tests/unit/automations/services/__init__.py b/surfsense_backend/tests/unit/automations/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py new file mode 100644 index 000000000..0bbff39dc --- /dev/null +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -0,0 +1,493 @@ +"""Lock creation-time model-policy enforcement in ``AutomationService``. + +Creation (REST + manual builder) rejects search spaces whose models aren't +billable for automations with HTTP 422, mirroring the runtime backstop. These +tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths +without touching the DB commit. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from fastapi import HTTPException + +import app.automations.services.automation as automation_mod +from app.automations.schemas.api import AutomationCreate, AutomationUpdate +from app.automations.schemas.definition.envelope import ( + AutomationDefinition, + AutomationModels, +) +from app.automations.schemas.definition.plan_step import PlanStep +from app.automations.services.automation import AutomationService +from app.automations.services.model_policy import AutomationModelPolicyError + +pytestmark = pytest.mark.unit + + +class _FakeSession: + def __init__(self, search_space: Any) -> None: + self._search_space = search_space + self.added: list[Any] = [] + self.commits = 0 + + async def get(self, _model: Any, _pk: int) -> Any: + return self._search_space + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.commits += 1 + + +def _service(search_space: Any) -> AutomationService: + return AutomationService( + session=_FakeSession(search_space), user=SimpleNamespace(id="u-1") + ) + + +def _definition(**kwargs: Any) -> AutomationDefinition: + return AutomationDefinition( + name="A", + plan=[PlanStep(step_id="s1", action="agent_task")], + **kwargs, + ) + + +async def test_assert_models_billable_raises_422_on_violation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A blocked model maps the policy error to HTTP 422.""" + + def _raise(_ss): + raise AutomationModelPolicyError( + [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] + ) + + monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) + + service = _service(SimpleNamespace(agent_llm_id=0)) + with pytest.raises(HTTPException) as exc_info: + await service._assert_models_billable(1) + + assert exc_info.value.status_code == 422 + + +async def test_assert_models_billable_raises_404_when_missing( + monkeypatch: 
pytest.MonkeyPatch, +) -> None: + """A missing search space is a 404, not a policy error.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + service = _service(None) + with pytest.raises(HTTPException) as exc_info: + await service._assert_models_billable(999) + + assert exc_info.value.status_code == 404 + + +async def test_assert_models_billable_returns_search_space_when_ok( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When the policy accepts, the loaded search space is returned for reuse.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + search_space = SimpleNamespace(agent_llm_id=-1) + service = _service(search_space) + assert await service._assert_models_billable(1) is search_space + + +async def test_create_injects_captured_models_from_search_space( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create() snapshots the search space's model prefs onto the definition.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + + async def _return_added(self, _aid): + return self.session.added[-1] + + monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) + + search_space = SimpleNamespace( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + service = _service(search_space) + payload = AutomationCreate( + search_space_id=1, + name="A", + definition=_definition(), + ) + + automation = await service.create(payload) + + assert automation.definition["models"] == { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + +async def test_create_treats_unset_prefs_as_auto_zero( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``None`` search-space prefs are captured as 
``0`` (Auto) ids.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + + async def _return_added(self, _aid): + return self.session.added[-1] + + monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) + + search_space = SimpleNamespace( + agent_llm_id=None, + image_generation_config_id=None, + vision_llm_config_id=None, + ) + service = _service(search_space) + payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) + + automation = await service.create(payload) + + assert automation.definition["models"] == { + "agent_llm_id": 0, + "image_generation_config_id": 0, + "vision_llm_config_id": 0, + } + + +async def test_create_honors_selected_models_when_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When the payload carries ``definition.models`` they are validated + kept. + + The search-space snapshot path is bypassed entirely (no + ``assert_automation_models_billable`` call). 
+ """ + + def _fail_snapshot(_ss): + raise AssertionError("snapshot path should not run when models are provided") + + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", _fail_snapshot + ) + validated: dict[str, Any] = {} + + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + validated["ids"] = ( + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, + ) + + monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_added(self, _aid): + return self.session.added[-1] + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) + + service = _service(SimpleNamespace(agent_llm_id=-99)) + payload = AutomationCreate( + search_space_id=1, + name="A", + definition=_definition( + models=AutomationModels( + agent_llm_id=-1, + image_generation_config_id=7, + vision_llm_config_id=-2, + ) + ), + ) + + automation = await service.create(payload) + + assert validated["ids"] == (-1, 7, -2) + assert automation.definition["models"] == { + "agent_llm_id": -1, + "image_generation_config_id": 7, + "vision_llm_config_id": -2, + } + + +async def test_create_rejects_unbillable_selected_models( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A non-billable explicit selection maps the policy error to HTTP 422.""" + + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + raise AutomationModelPolicyError( + [{"kind": "llm", "config_id": -3, "reason": "free model"}] + ) + + monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) + + async def _noop_authorize(self, *_a, **_k): + return None + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + + service = _service(SimpleNamespace(agent_llm_id=-3)) + payload = AutomationCreate( + search_space_id=1, + name="A", + 
definition=_definition( + models=AutomationModels( + agent_llm_id=-3, + image_generation_config_id=7, + vision_llm_config_id=-2, + ) + ), + ) + + with pytest.raises(HTTPException) as exc_info: + await service.create(payload) + + assert exc_info.value.status_code == 422 + + +async def test_update_preserves_captured_models( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A definition edit carries over the previously captured ``models``.""" + captured = { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + existing = SimpleNamespace( + search_space_id=1, + definition={"name": "A", "plan": [], "models": captured}, + version=3, + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_existing(self, _aid): + return existing + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr( + AutomationService, "_get_with_triggers_or_raise", _return_existing + ) + + service = _service(SimpleNamespace()) + # The incoming patch definition has no ``models`` (frontend strips it). 
+ patch = AutomationUpdate(definition=_definition()) + + result = await service.update(7, patch) + + assert result.definition["models"] == captured + assert result.version == 4 + + +async def test_update_honors_changed_models_when_valid( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A definition edit with a *changed* models block validates + keeps it.""" + existing = SimpleNamespace( + search_space_id=1, + definition={ + "name": "A", + "plan": [], + "models": { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + }, + }, + version=3, + ) + validated: dict[str, Any] = {} + + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + validated["ids"] = ( + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, + ) + + monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_existing(self, _aid): + return existing + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr( + AutomationService, "_get_with_triggers_or_raise", _return_existing + ) + + service = _service(SimpleNamespace()) + patch = AutomationUpdate( + definition=_definition( + models=AutomationModels( + agent_llm_id=-2, + image_generation_config_id=9, + vision_llm_config_id=-2, + ) + ) + ) + + result = await service.update(7, patch) + + assert validated["ids"] == (-2, 9, -2) + assert result.definition["models"] == { + "agent_llm_id": -2, + "image_generation_config_id": 9, + "vision_llm_config_id": -2, + } + assert result.version == 4 + + +async def test_update_rejects_changed_unbillable_models( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A *changed* non-billable models block is rejected with HTTP 422.""" + existing = SimpleNamespace( + search_space_id=1, + definition={ + "name": "A", + "plan": [], + "models": { + "agent_llm_id": -1, + "image_generation_config_id": 5, + 
"vision_llm_config_id": -1, + }, + }, + version=3, + ) + + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + raise AutomationModelPolicyError( + [{"kind": "llm", "config_id": -7, "reason": "free model"}] + ) + + monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_existing(self, _aid): + return existing + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr( + AutomationService, "_get_with_triggers_or_raise", _return_existing + ) + + service = _service(SimpleNamespace()) + patch = AutomationUpdate( + definition=_definition( + models=AutomationModels( + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + ) + ) + + with pytest.raises(HTTPException) as exc_info: + await service.update(7, patch) + + assert exc_info.value.status_code == 422 + + +async def test_update_keeps_unchanged_models_without_revalidation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An unchanged models block is kept as-is and is NOT re-validated. + + Lets users edit an automation whose captured model later drifted out of + premium without an unrelated edit tripping the policy check. 
+ """ + captured = { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + existing = SimpleNamespace( + search_space_id=1, + definition={"name": "A", "plan": [], "models": captured}, + version=3, + ) + + def _fail(*_a, **_k): + raise AssertionError("unchanged models must not be re-validated") + + monkeypatch.setattr(automation_mod, "assert_models_billable", _fail) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_existing(self, _aid): + return existing + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr( + AutomationService, "_get_with_triggers_or_raise", _return_existing + ) + + service = _service(SimpleNamespace()) + patch = AutomationUpdate( + definition=_definition(models=AutomationModels(**captured)) + ) + + result = await service.update(7, patch) + + assert result.definition["models"] == captured + assert result.version == 4 + + +async def test_model_eligibility_authorizes_and_returns_payload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``model_eligibility`` checks read access then returns the eligibility dict.""" + authorized: dict[str, Any] = {} + + async def _fake_check_permission(_session, _user, ss_id, permission, _msg): + authorized["search_space_id"] = ss_id + authorized["permission"] = permission + + monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission) + monkeypatch.setattr( + automation_mod, + "get_automation_model_eligibility", + lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, + ) + + service = _service(SimpleNamespace(agent_llm_id=-2)) + result = await service.model_eligibility(search_space_id=5) + + assert result == {"allowed": False, "violations": [{"kind": "image"}]} + assert authorized["search_space_id"] == 5 + assert authorized["permission"] == "automations:read" diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py 
b/surfsense_backend/tests/unit/automations/services/test_model_policy.py new file mode 100644 index 000000000..2a471b4e9 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -0,0 +1,196 @@ +"""Lock the automation model-billing policy. + +Automations may only run on billable models: premium global configs +(``billing_tier == "premium"``) or user BYOK configs (positive id). Free +globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule +across all three model slots (chat LLM, image, vision). +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import app.automations.services.model_policy as model_policy +from app.automations.services.model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + assert_models_billable, + get_automation_model_eligibility, + get_model_eligibility, +) + +pytestmark = pytest.mark.unit + + +def _search_space(*, llm: int | None, image: int | None, vision: int | None): + """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" + return SimpleNamespace( + agent_llm_id=llm, + image_generation_config_id=image, + vision_llm_config_id=vision, + ) + + +@pytest.fixture +def patched_globals(monkeypatch: pytest.MonkeyPatch): + """Stub the global config sources the policy consults for negative ids. + + Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. 
+ """ + llm_configs = { + -1: {"id": -1, "billing_tier": "premium"}, + -2: {"id": -2, "billing_tier": "free"}, + } + monkeypatch.setattr( + "app.agents.new_chat.llm_config.load_global_llm_config_by_id", + lambda cid: llm_configs.get(cid), + ) + + from app.config import config as app_config + + monkeypatch.setattr( + app_config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + monkeypatch.setattr( + app_config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + return None + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: + """A positive config id is a user-owned BYOK model — always billable.""" + allowed, reason = model_policy._classify(kind, 7) + assert allowed is True + assert reason == "" + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("config_id", [0, None]) +def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: + """Auto mode (id 0) and an unset slot (None) are blocked.""" + allowed, reason = model_policy._classify(kind, config_id) + assert allowed is False + assert "Auto mode" in reason + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_premium_global_is_allowed(kind: str, patched_globals) -> None: + """A negative (global) id with premium billing tier is allowed.""" + allowed, reason = model_policy._classify(kind, -1) + assert allowed is True + assert reason == "" + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_free_global_is_blocked(kind: str, patched_globals) -> None: + """A negative (global) id with a free billing tier is blocked.""" + allowed, reason = model_policy._classify(kind, -2) + assert allowed is False + assert "free model" in reason + + 
+@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: + """A negative id that resolves to no config is treated as not premium.""" + allowed, _ = model_policy._classify(kind, -999) + assert allowed is False + + +def test_eligibility_all_billable(patched_globals) -> None: + """Premium LLM + BYOK image + premium vision → allowed, no violations.""" + search_space = _search_space(llm=-1, image=5, vision=-1) + result = get_automation_model_eligibility(search_space) + assert result == {"allowed": True, "violations": []} + + +def test_eligibility_reports_each_violation(patched_globals) -> None: + """A free LLM, Auto image, and free vision each produce a violation.""" + search_space = _search_space(llm=-2, image=0, vision=-2) + result = get_automation_model_eligibility(search_space) + + assert result["allowed"] is False + kinds = {v["kind"] for v in result["violations"]} + assert kinds == {"llm", "image", "vision"} + # config_id is echoed back for the UI / settings deep-link. 
+ by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} + + +def test_assert_raises_with_violations(patched_globals) -> None: + """``assert_automation_models_billable`` raises when any slot is blocked.""" + search_space = _search_space(llm=0, image=5, vision=-1) + with pytest.raises(AutomationModelPolicyError) as exc_info: + assert_automation_models_billable(search_space) + + assert len(exc_info.value.violations) == 1 + assert exc_info.value.violations[0]["kind"] == "llm" + + +def test_assert_passes_when_all_billable(patched_globals) -> None: + """No exception when every slot is premium or BYOK.""" + search_space = _search_space(llm=3, image=-1, vision=4) + assert assert_automation_models_billable(search_space) is None + + +# --- ID-based core (used by the runtime backstop against captured snapshots) --- + + +def test_get_model_eligibility_all_billable(patched_globals) -> None: + """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" + result = get_model_eligibility( + agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 + ) + assert result == {"allowed": True, "violations": []} + + +def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: + """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" + result = get_model_eligibility( + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + ) + assert result["allowed"] is False + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} + + +def test_assert_models_billable_raises(patched_globals) -> None: + """``assert_models_billable`` raises when any explicit id is blocked.""" + with pytest.raises(AutomationModelPolicyError) as exc_info: + assert_models_billable( + agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 + ) + assert len(exc_info.value.violations) == 
1 + assert exc_info.value.violations[0]["kind"] == "llm" + + +def test_assert_models_billable_passes(patched_globals) -> None: + """No exception when every explicit id is premium or BYOK.""" + assert ( + assert_models_billable( + agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 + ) + is None + ) + + +def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: + """The search-space wrapper produces the same result as the ID core.""" + search_space = _search_space(llm=-2, image=0, vision=-2) + assert get_automation_model_eligibility(search_space) == get_model_eligibility( + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + ) diff --git a/surfsense_backend/tests/unit/automations/templating/__init__.py b/surfsense_backend/tests/unit/automations/templating/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/templating/test_context.py b/surfsense_backend/tests/unit/automations/templating/test_context.py new file mode 100644 index 000000000..54f372e77 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/templating/test_context.py @@ -0,0 +1,53 @@ +"""Lock the ``{run, inputs, steps}`` namespace exposed to every template.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import UUID + +import pytest + +from app.automations.templating.context import build_run_context + +pytestmark = pytest.mark.unit + + +def test_build_run_context_exposes_run_inputs_and_steps_namespaces() -> None: + """The namespace handed to templates groups run metadata under ``run``, + runtime + static inputs under ``inputs``, and step outputs (keyed by + ``output_as`` / ``step_id``) under ``steps``. 
Locks the contract that + every plan template body relies on.""" + creator = UUID("00000000-0000-0000-0000-000000000001") + started = datetime(2026, 5, 28, 14, 30, tzinfo=UTC) + + ctx = build_run_context( + run_id=42, + automation_id=7, + automation_name="Weekly digest", + automation_version=3, + search_space_id=1, + creator_id=creator, + trigger_id=11, + trigger_type="schedule", + started_at=started, + attempt=2, + inputs={"topic": "weekly"}, + step_outputs={"summarize": {"text": "ok"}}, + ) + + assert ctx == { + "run": { + "id": 42, + "automation_id": 7, + "automation_name": "Weekly digest", + "automation_version": 3, + "search_space_id": 1, + "creator_id": creator, + "trigger_id": 11, + "trigger_type": "schedule", + "started_at": started, + "attempt": 2, + }, + "inputs": {"topic": "weekly"}, + "steps": {"summarize": {"text": "ok"}}, + } diff --git a/surfsense_backend/tests/unit/automations/templating/test_environment.py b/surfsense_backend/tests/unit/automations/templating/test_environment.py new file mode 100644 index 000000000..64850c9c5 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/templating/test_environment.py @@ -0,0 +1,53 @@ +"""Lock the sandbox boundary: disallowed filters/tests reject, finalize coerces non-strings. + +These behaviors live in ``environment.py`` but are observed through the +public ``render_template`` surface — the same surface every step uses. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest +from jinja2.exceptions import TemplateError + +from app.automations.templating.render import render_template + +pytestmark = pytest.mark.unit + + +def test_environment_rejects_filters_not_in_the_allowlist() -> None: + """A template that pipes through a Jinja built-in **not** in the + allowlist (e.g. ``pprint``) must fail rather than rendering. 
Locks + the sandbox surface against accidental re-introduction of removed + filters.""" + with pytest.raises(TemplateError): + render_template("{{ value | pprint }}", {"value": {"k": 1}}) + + +def test_environment_finalizes_datetime_output_to_iso_string() -> None: + """A datetime that lands directly at an output site is stringified + via ``isoformat()`` rather than producing ``str(datetime)`` (which + has a space separator). Locks the wire shape templates produce + when emitting ``inputs.fired_at`` and other datetime values.""" + dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC) + + assert ( + render_template("{{ moment }}", {"moment": dt}) == "2026-05-28T14:30:00+00:00" + ) + + +def test_environment_finalizes_none_output_to_empty_string() -> None: + """A ``None`` at an output site becomes the empty string. Lets + templates write ``{{ inputs.last_fired_at }}`` unconditionally on + the first run without exploding on the null.""" + assert render_template("{{ missing }}", {"missing": None}) == "" + + +def test_environment_finalizes_dict_output_to_json() -> None: + """A dict at an output site is JSON-serialized. Same for lists. 
+ Locks the wire shape so users embedding structured values into + prompts get deterministic, parseable output.""" + rendered = render_template("{{ payload }}", {"payload": {"a": 1, "b": [2, 3]}}) + + assert rendered == '{"a": 1, "b": [2, 3]}' diff --git a/surfsense_backend/tests/unit/automations/templating/test_filters.py b/surfsense_backend/tests/unit/automations/templating/test_filters.py new file mode 100644 index 000000000..cf83ee337 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/templating/test_filters.py @@ -0,0 +1,42 @@ +"""Lock the custom Jinja filters: ``date`` and ``slugify``.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from app.automations.templating.filters import filter_date, filter_slugify + +pytestmark = pytest.mark.unit + + +def test_filter_slugify_produces_url_safe_slug_from_typical_title() -> None: + """``filter_slugify`` lowercases, replaces non-alphanumerics with + hyphens, collapses repeats, and trims edge hyphens — the standard + URL-slug contract users expect when piping titles into paths.""" + assert filter_slugify("Hello, World! 2026") == "hello-world-2026" + + +def test_filter_date_formats_datetime_with_strftime_format() -> None: + """``filter_date`` calls ``strftime`` on datetime-like values with the + provided format. Default format yields ISO date (YYYY-MM-DD).""" + dt = datetime(2026, 5, 28, 14, 30, tzinfo=UTC) + + assert filter_date(dt) == "2026-05-28" + assert filter_date(dt, "%Y/%m/%d %H:%M") == "2026/05/28 14:30" + + +def test_filter_date_returns_empty_string_for_none() -> None: + """``None`` (e.g., a never-fired ``last_fired_at``) renders as the + empty string rather than the literal ``"None"`` or raising. 
This is + what lets templates write ``{{ inputs.last_fired_at | date }}`` + unconditionally on the first run.""" + assert filter_date(None) == "" + + +def test_filter_date_passes_strings_through_unchanged() -> None: + """Already-formatted ISO strings (the JSON-serialized shape of + runtime inputs like ``fired_at``) pass through unchanged so callers + don't have to special-case the type.""" + assert filter_date("2026-05-28T14:30:00+00:00") == "2026-05-28T14:30:00+00:00" diff --git a/surfsense_backend/tests/unit/automations/templating/test_render.py b/surfsense_backend/tests/unit/automations/templating/test_render.py new file mode 100644 index 000000000..42a7c7082 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/templating/test_render.py @@ -0,0 +1,59 @@ +"""Lock the public template-rendering surface: render, predicate, recursive.""" + +from __future__ import annotations + +import pytest +from jinja2 import UndefinedError + +from app.automations.templating.render import ( + evaluate_predicate, + render_template, + render_value, +) + +pytestmark = pytest.mark.unit + + +def test_render_template_substitutes_context_variables() -> None: + """A template referencing a context variable produces the substituted + string. Most basic contract of the template engine.""" + result = render_template("Hello {{ name }}!", {"name": "World"}) + + assert result == "Hello World!" + + +def test_render_template_raises_on_undefined_variable() -> None: + """Referencing a variable that isn't in the context raises rather than + rendering the empty string. Locks the StrictUndefined safety net so + template typos surface as run failures instead of silent corruption.""" + with pytest.raises(UndefinedError): + render_template("Hello {{ missing }}!", {}) + + +def test_evaluate_predicate_returns_truthy_outcome_of_expression() -> None: + """``evaluate_predicate`` compiles a Jinja **expression** (not template + body) and coerces the value to ``bool``. 
Drives ``step.when`` gating.""" + assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 3}}) is True + assert evaluate_predicate("inputs.count > 0", {"inputs": {"count": 0}}) is False + + +def test_render_value_renders_strings_recursively_through_dicts_and_lists() -> None: + """``render_value`` walks dicts and lists, renders string leaves through + the template engine, and leaves non-strings untouched. This is the + primitive ``execute_step`` uses to render step params at run time.""" + context = {"inputs": {"name": "World"}, "topic": "weekly"} + + rendered = render_value( + { + "greeting": "Hello {{ inputs.name }}", + "tags": ["{{ topic }}", "static"], + "config": {"retries": 3, "label": "{{ topic }}-{{ inputs.name }}"}, + }, + context, + ) + + assert rendered == { + "greeting": "Hello World", + "tags": ["weekly", "static"], + "config": {"retries": 3, "label": "weekly-World"}, + } diff --git a/surfsense_backend/tests/unit/automations/test_definition_types.py b/surfsense_backend/tests/unit/automations/test_definition_types.py new file mode 100644 index 000000000..2320b61d3 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/test_definition_types.py @@ -0,0 +1,56 @@ +"""Lock the ``params_schema`` derivation on action + trigger definitions. + +Both definition dataclasses expose ``params_schema`` as the JSON Schema +of their ``params_model``. This is what the registry endpoints surface +to the UI as the "what shape do these params take?" contract. 
+""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from app.automations.actions.types import ActionDefinition +from app.automations.triggers.types import TriggerDefinition + +pytestmark = pytest.mark.unit + + +class _Topic(BaseModel): + """Model with one required string field — minimal schema fingerprint.""" + + topic: str + + +def test_action_definition_params_schema_reflects_params_model() -> None: + """``ActionDefinition.params_schema`` returns a JSON Schema derived + from the Pydantic ``params_model`` — required fields and types are + visible to clients consuming the registry endpoint.""" + definition = ActionDefinition( + type="t", + name="N", + description="D", + params_model=_Topic, + build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value] + ) + + schema = definition.params_schema + + assert schema["type"] == "object" + assert schema["properties"]["topic"]["type"] == "string" + assert "topic" in schema["required"] + + +def test_trigger_definition_params_schema_reflects_params_model() -> None: + """Same JSON-Schema derivation contract on the trigger side.""" + definition = TriggerDefinition( + type="t", + description="D", + params_model=_Topic, + ) + + schema = definition.params_schema + + assert schema["type"] == "object" + assert schema["properties"]["topic"]["type"] == "string" + assert "topic" in schema["required"] diff --git a/surfsense_backend/tests/unit/automations/test_import_registrations.py b/surfsense_backend/tests/unit/automations/test_import_registrations.py new file mode 100644 index 000000000..35b1effa7 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/test_import_registrations.py @@ -0,0 +1,37 @@ +"""Lock the bundled import side-effects. + +Importing ``app.automations`` (the package) registers the v1 bundled +action (``agent_task``) and the v1 bundled trigger (``schedule``). If the +import chain breaks (e.g. someone removes ``from . 
import definition`` +in a sub-package ``__init__``), the system would silently launch with an +empty registry. These tests are the canary. +""" + +from __future__ import annotations + +import pytest + +import app.automations # noqa: F401 (force the package import + its side-effects) +from app.automations.actions.store import get_action +from app.automations.persistence.enums.trigger_type import TriggerType +from app.automations.triggers.store import get_trigger + +pytestmark = pytest.mark.unit + + +def test_bundled_agent_task_action_is_registered_after_package_import() -> None: + """``agent_task`` — the v1 default action — must be discoverable in + the registry after the package is imported.""" + definition = get_action("agent_task") + + assert definition is not None + assert definition.type == "agent_task" + + +def test_bundled_schedule_trigger_is_registered_after_package_import() -> None: + """``schedule`` — the only v1 trigger — must be discoverable in the + registry after the package is imported.""" + definition = get_trigger(TriggerType.SCHEDULE.value) + + assert definition is not None + assert definition.type == TriggerType.SCHEDULE.value diff --git a/surfsense_backend/tests/unit/automations/test_persistence_enums.py b/surfsense_backend/tests/unit/automations/test_persistence_enums.py new file mode 100644 index 000000000..23da613ed --- /dev/null +++ b/surfsense_backend/tests/unit/automations/test_persistence_enums.py @@ -0,0 +1,45 @@ +"""Lock the persistence enum string values + members. + +These enums are mirrored by Postgres enum types, embedded in stored DB +rows, and surfaced in the JSON API. Renaming a value (or removing a +member) silently breaks production data and previously-issued API +responses, so the strings + the set of members are the contract. 
+""" + +from __future__ import annotations + +import pytest + +from app.automations.persistence.enums.automation_status import AutomationStatus +from app.automations.persistence.enums.run_status import RunStatus +from app.automations.persistence.enums.trigger_type import TriggerType + +pytestmark = pytest.mark.unit + + +def test_automation_status_string_values_are_stable() -> None: + """The exact strings persisted to Postgres and served in API JSON.""" + assert {member.value for member in AutomationStatus} == { + "active", + "paused", + "archived", + } + + +def test_run_status_string_values_are_stable() -> None: + """Run lifecycle states embedded in the ``automation_runs`` table.""" + assert {member.value for member in RunStatus} == { + "pending", + "running", + "succeeded", + "failed", + "cancelled", + "timed_out", + } + + +def test_trigger_type_keeps_manual_member_even_though_unregistered() -> None: + """``schedule`` and ``event`` are registered; ``MANUAL`` is reserved + (mirrors the Postgres enum) but the trigger store does not register it. + The enum must keep every member so DB rows and migrations stay valid.""" + assert {member.value for member in TriggerType} == {"schedule", "event", "manual"} diff --git a/surfsense_backend/tests/unit/automations/test_stores.py b/surfsense_backend/tests/unit/automations/test_stores.py new file mode 100644 index 000000000..d005d7be7 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/test_stores.py @@ -0,0 +1,117 @@ +"""Lock the trigger + action registry contracts. + +Both stores share the same API shape (register/get/all + duplicate-raise), +so they're tested together to keep the contract visible side-by-side. 
+""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from app.automations.actions.store import ( + get_action, + register_action, +) +from app.automations.actions.types import ActionDefinition +from app.automations.triggers.store import ( + all_triggers, + get_trigger, + register_trigger, +) +from app.automations.triggers.types import TriggerDefinition + +pytestmark = pytest.mark.unit + + +class _Params(BaseModel): + """Empty params model used by test-only registrations.""" + + +def _trigger(type_: str = "test_trigger") -> TriggerDefinition: + return TriggerDefinition( + type=type_, description="Test trigger.", params_model=_Params + ) + + +def _action(type_: str = "test_action") -> ActionDefinition: + return ActionDefinition( + type=type_, + name="Test", + description="Test action.", + params_model=_Params, + build_handler=lambda _ctx: lambda _p: {}, # type: ignore[arg-type,return-value] + ) + + +def test_register_trigger_then_get_trigger_returns_the_same_definition( + isolated_trigger_registry: None, +) -> None: + """The canonical round-trip: register, look up by type, get the same + definition back. Locks the basic registry contract.""" + definition = _trigger() + register_trigger(definition) + + assert get_trigger("test_trigger") is definition + + +def test_register_action_then_get_action_returns_the_same_definition( + isolated_action_registry: None, +) -> None: + """Same round-trip contract for the action registry.""" + definition = _action() + register_action(definition) + + assert get_action("test_action") is definition + + +def test_get_trigger_returns_none_for_unknown_type( + isolated_trigger_registry: None, +) -> None: + """An unknown type returns ``None`` (not raises). Lets callers like + the dispatcher branch on "is this trigger still registered?" 
without + try/except.""" + assert get_trigger("never_registered") is None + + +def test_get_action_returns_none_for_unknown_type( + isolated_action_registry: None, +) -> None: + """Same ``None``-not-raise contract on the action side.""" + assert get_action("never_registered") is None + + +def test_register_trigger_rejects_duplicate_type( + isolated_trigger_registry: None, +) -> None: + """Re-registering the same ``type`` raises rather than silently + overwriting. Locks the safety net against accidental double-import + (e.g., circular imports re-running the registration block).""" + register_trigger(_trigger()) + + with pytest.raises(ValueError, match="test_trigger"): + register_trigger(_trigger()) + + +def test_register_action_rejects_duplicate_type( + isolated_action_registry: None, +) -> None: + """Same duplicate-rejection contract on the action side.""" + register_action(_action()) + + with pytest.raises(ValueError, match="test_action"): + register_action(_action()) + + +def test_all_triggers_returns_defensive_snapshot( + isolated_trigger_registry: None, +) -> None: + """``all_triggers()`` returns a copy: mutating the returned dict does + not corrupt the internal registry. 
Locks the snapshot contract that + UI/listing endpoints rely on.""" + register_trigger(_trigger("snapshot_test")) + + snapshot = all_triggers() + snapshot.pop("snapshot_test") + + assert get_trigger("snapshot_test") is not None diff --git a/surfsense_backend/tests/unit/automations/triggers/__init__.py b/surfsense_backend/tests/unit/automations/triggers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/__init__.py b/surfsense_backend/tests/unit/automations/triggers/builtin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/event/__init__.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_definition.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_definition.py new file mode 100644 index 000000000..479943cc2 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_definition.py @@ -0,0 +1,18 @@ +"""The ``event`` trigger self-registers on the triggers store at import.""" + +from __future__ import annotations + +import pytest + +from app.automations.triggers import get_trigger +from app.automations.triggers.builtin.event.params import EventTriggerParams + +pytestmark = pytest.mark.unit + + +def test_event_trigger_is_registered() -> None: + definition = get_trigger("event") + + assert definition is not None + assert definition.type == "event" + assert definition.params_model is EventTriggerParams diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_filter.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_filter.py new file mode 100644 index 000000000..9ddc3503a --- /dev/null +++ 
b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_filter.py @@ -0,0 +1,115 @@ +"""Behavior tests for the ``matches`` filter grammar.""" + +from __future__ import annotations + +import pytest + +from app.automations.triggers.builtin.event.filter import FilterError, matches + +pytestmark = pytest.mark.unit + + +def test_empty_filter_matches_any_payload() -> None: + assert matches({}, {"document_id": 42, "document_type": "FILE"}) is True + assert matches({}, {}) is True + + +def test_scalar_value_is_implicit_equality() -> None: + flt = {"document_type": "FILE"} + assert matches(flt, {"document_type": "FILE"}) is True + assert matches(flt, {"document_type": "WEBPAGE"}) is False + + +def test_multiple_fields_are_anded() -> None: + flt = {"document_type": "FILE", "search_space_id": 7} + assert matches(flt, {"document_type": "FILE", "search_space_id": 7}) is True + assert matches(flt, {"document_type": "FILE", "search_space_id": 9}) is False + + +def test_gt_operator_compares_greater_than() -> None: + flt = {"page_count": {"$gt": 10}} + assert matches(flt, {"page_count": 20}) is True + assert matches(flt, {"page_count": 10}) is False + assert matches(flt, {"page_count": 5}) is False + + +def test_remaining_comparison_operators() -> None: + assert matches({"n": {"$gte": 10}}, {"n": 10}) is True + assert matches({"n": {"$gte": 10}}, {"n": 9}) is False + + assert matches({"n": {"$lt": 10}}, {"n": 9}) is True + assert matches({"n": {"$lt": 10}}, {"n": 10}) is False + + assert matches({"n": {"$lte": 10}}, {"n": 10}) is True + assert matches({"n": {"$lte": 10}}, {"n": 11}) is False + + assert matches({"s": {"$eq": "FILE"}}, {"s": "FILE"}) is True + assert matches({"s": {"$eq": "FILE"}}, {"s": "WEB"}) is False + + assert matches({"s": {"$ne": "FILE"}}, {"s": "WEB"}) is True + assert matches({"s": {"$ne": "FILE"}}, {"s": "FILE"}) is False + + +def test_multiple_operators_on_one_field_are_anded() -> None: + flt = {"n": {"$gte": 10, "$lt": 20}} + assert 
matches(flt, {"n": 15}) is True + assert matches(flt, {"n": 10}) is True + assert matches(flt, {"n": 20}) is False + assert matches(flt, {"n": 5}) is False + + +def test_in_and_nin_membership_operators() -> None: + flt_in = {"document_type": {"$in": ["FILE", "WEBPAGE"]}} + assert matches(flt_in, {"document_type": "FILE"}) is True + assert matches(flt_in, {"document_type": "SLACK"}) is False + + flt_nin = {"document_type": {"$nin": ["FILE", "WEBPAGE"]}} + assert matches(flt_nin, {"document_type": "SLACK"}) is True + assert matches(flt_nin, {"document_type": "FILE"}) is False + + +def test_or_matches_when_any_branch_holds() -> None: + flt = {"$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}]} + assert matches(flt, {"document_type": "WEBPAGE"}) is True + assert matches(flt, {"document_type": "SLACK"}) is False + + +def test_and_matches_when_every_branch_holds() -> None: + flt = {"$and": [{"n": {"$gt": 5}}, {"n": {"$lt": 10}}]} + assert matches(flt, {"n": 7}) is True + assert matches(flt, {"n": 12}) is False + + +def test_not_inverts_its_subexpression() -> None: + flt = {"$not": {"document_type": "FILE"}} + assert matches(flt, {"document_type": "WEBPAGE"}) is True + assert matches(flt, {"document_type": "FILE"}) is False + + +def test_missing_field_never_matches_and_never_raises() -> None: + # Conservative: an absent field fails the constraint, and comparisons must + # not raise on the missing value — including $ne (absence isn't "not equal"). 
+ assert matches({"document_type": "FILE"}, {}) is False + assert matches({"page_count": {"$gt": 5}}, {}) is False + assert matches({"document_type": {"$in": ["FILE"]}}, {}) is False + assert matches({"document_type": {"$ne": "FILE"}}, {}) is False + + +def test_logical_operators_compose_with_fields() -> None: + flt = { + "search_space_id": 7, + "$or": [{"document_type": "FILE"}, {"document_type": "WEBPAGE"}], + } + assert matches(flt, {"search_space_id": 7, "document_type": "FILE"}) is True + assert matches(flt, {"search_space_id": 9, "document_type": "FILE"}) is False + assert matches(flt, {"search_space_id": 7, "document_type": "SLACK"}) is False + + +def test_unknown_field_operator_raises_filter_error() -> None: + with pytest.raises(FilterError): + matches({"n": {"$regex": "x"}}, {"n": "xyz"}) + + +def test_unknown_logical_operator_raises_filter_error() -> None: + with pytest.raises(FilterError): + matches({"$nor": [{"document_type": "FILE"}]}, {"document_type": "FILE"}) diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_inputs.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_inputs.py new file mode 100644 index 000000000..e6191d7a7 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_inputs.py @@ -0,0 +1,26 @@ +"""An event hands its payload + metadata to the run as inputs.""" + +from __future__ import annotations + +import pytest + +from app.automations.triggers.builtin.event.inputs import event_runtime_inputs +from app.event_bus import Event + +pytestmark = pytest.mark.unit + + +def test_runtime_inputs_flatten_payload_with_event_metadata() -> None: + event = Event( + event_type="document.indexed", + payload={"document_id": 42, "document_type": "FILE"}, + search_space_id=7, + ) + + inputs = event_runtime_inputs(event) + + assert inputs["document_id"] == 42 + assert inputs["document_type"] == "FILE" + assert inputs["event_type"] == "document.indexed" + assert 
inputs["event_id"] == event.event_id + assert inputs["occurred_at"] == event.occurred_at.isoformat() diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_match.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_match.py new file mode 100644 index 000000000..d83db97a4 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_match.py @@ -0,0 +1,39 @@ +"""Which triggers an event fires: event_type equality + filter match.""" + +from __future__ import annotations + +import pytest + +from app.automations.triggers.builtin.event.match import trigger_matches_event +from app.event_bus import Event + +pytestmark = pytest.mark.unit + + +def _event(event_type: str = "document.indexed", **payload) -> Event: + return Event(event_type=event_type, payload=payload, search_space_id=7) + + +def test_matches_when_event_type_equal_and_filter_passes() -> None: + params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}} + assert trigger_matches_event(params, _event(document_type="FILE")) is True + + +def test_no_match_when_event_type_differs() -> None: + params = {"event_type": "document.indexed", "filter": {}} + assert trigger_matches_event(params, _event("podcast.generated")) is False + + +def test_no_match_when_filter_rejects_payload() -> None: + params = {"event_type": "document.indexed", "filter": {"document_type": "FILE"}} + assert trigger_matches_event(params, _event(document_type="WEBPAGE")) is False + + +def test_empty_filter_matches_any_payload_of_that_type() -> None: + params = {"event_type": "document.indexed", "filter": {}} + assert trigger_matches_event(params, _event(document_type="ANYTHING")) is True + + +def test_missing_filter_key_is_treated_as_empty() -> None: + params = {"event_type": "document.indexed"} + assert trigger_matches_event(params, _event(document_type="X")) is True diff --git 
a/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_params.py b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_params.py new file mode 100644 index 000000000..fef3b0b94 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/event/test_params.py @@ -0,0 +1,40 @@ +"""``EventTriggerParams`` contract: an event_type to listen for + an optional filter.""" + +from __future__ import annotations + +import pytest + +from app.automations.triggers.builtin.event.params import EventTriggerParams + +pytestmark = pytest.mark.unit + + +def test_accepts_event_type_and_filter() -> None: + params = EventTriggerParams( + event_type="document.indexed", + filter={"document_type": "FILE"}, + ) + + assert params.event_type == "document.indexed" + assert params.filter == {"document_type": "FILE"} + + +def test_filter_defaults_to_empty() -> None: + params = EventTriggerParams(event_type="document.indexed") + + assert params.filter == {} + + +def test_event_type_is_required() -> None: + with pytest.raises(ValueError): + EventTriggerParams(filter={"x": 1}) + + +def test_event_type_must_not_be_blank() -> None: + with pytest.raises(ValueError): + EventTriggerParams(event_type="") + + +def test_extra_keys_are_forbidden() -> None: + with pytest.raises(ValueError): + EventTriggerParams(event_type="document.indexed", typo=True) diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/__init__.py b/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_cron.py b/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_cron.py new file mode 100644 index 000000000..618b82f2a --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_cron.py @@ -0,0 +1,86 @@ +"""Lock the cron + timezone + UTC normalization 
contract.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from app.automations.triggers.builtin.schedule.cron import ( + InvalidCronError, + compute_next_fire_at, + validate_cron, +) + +pytestmark = pytest.mark.unit + + +def test_compute_next_fire_at_returns_next_match_normalized_to_utc() -> None: + """``compute_next_fire_at`` evaluates the cron in the given IANA timezone + and returns the next strictly-later match expressed in UTC. + + Setup: ``0 9 * * 1-5`` (09:00 Monday-Friday) in ``Africa/Kigali`` + (UTC+2, no DST). With ``after`` = Tuesday 05:00 UTC (= 07:00 local), + the next fire is the same Tuesday at 09:00 local = 07:00 UTC. + """ + after = datetime(2026, 5, 26, 5, 0, tzinfo=UTC) # Tue 07:00 Kigali + + next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after) + + assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC) + + +def test_compute_next_fire_at_respects_dst_offset_change() -> None: + """A daily cron in a DST-observing tz fires at the same local hour + across the DST boundary, which produces a different UTC offset on + either side of the transition. + + Setup: ``0 9 * * *`` (09:00 every day) in ``America/New_York``. + NY is UTC-5 in winter (EST), UTC-4 in summer (EDT). 
Evaluating from + each side of the spring-forward in 2026 (Sun Mar 8 at 02:00 → 03:00): + + - winter: ``after`` = 2026-02-15 (EST, UTC-5) → next 09:00 EST = 14:00 UTC + - summer: ``after`` = 2026-04-15 (EDT, UTC-4) → next 09:00 EDT = 13:00 UTC + """ + winter_after = datetime(2026, 2, 15, 0, 0, tzinfo=UTC) + summer_after = datetime(2026, 4, 15, 0, 0, tzinfo=UTC) + + winter_fire = compute_next_fire_at( + "0 9 * * *", "America/New_York", after=winter_after + ) + summer_fire = compute_next_fire_at( + "0 9 * * *", "America/New_York", after=summer_after + ) + + assert winter_fire == datetime(2026, 2, 15, 14, 0, tzinfo=UTC) + assert summer_fire == datetime(2026, 4, 15, 13, 0, tzinfo=UTC) + + +def test_compute_next_fire_at_is_strictly_after_when_after_equals_a_match() -> None: + """When ``after`` lands exactly on a cron match, the result jumps to the + next match — never the same instant. Required so the schedule-tick + can pass ``next_fire_at`` itself as ``after`` to advance to the + following slot without double-firing. + + Setup: weekday 09:00 Kigali. ``after`` = Mon 09:00 Kigali = 07:00 UTC + (an exact match) → next fire must be Tue 09:00 Kigali = next day 07:00 UTC. 
+ """ + after = datetime(2026, 5, 25, 7, 0, tzinfo=UTC) # Mon 09:00 Kigali — exact match + + next_fire = compute_next_fire_at("0 9 * * 1-5", "Africa/Kigali", after=after) + + assert next_fire == datetime(2026, 5, 26, 7, 0, tzinfo=UTC) # Tue 09:00 Kigali + + +def test_validate_cron_rejects_malformed_cron_expression() -> None: + """A syntactically invalid cron must be rejected at validation time so + bad triggers can't reach storage and explode at fire time.""" + with pytest.raises(InvalidCronError): + validate_cron("this is not cron", "UTC") + + +def test_validate_cron_rejects_unknown_timezone() -> None: + """A non-IANA timezone string must be rejected at validation time — + the same protective gate as the cron expression itself.""" + with pytest.raises(InvalidCronError): + validate_cron("0 9 * * *", "Mars/Olympus_Mons") diff --git a/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_params.py b/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_params.py new file mode 100644 index 000000000..bd9ebc621 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/triggers/builtin/schedule/test_params.py @@ -0,0 +1,34 @@ +"""Lock the ``ScheduleTriggerParams`` validation contract.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.automations.triggers.builtin.schedule.params import ScheduleTriggerParams + +pytestmark = pytest.mark.unit + + +def test_schedule_params_accept_valid_cron_and_iana_timezone() -> None: + """A well-formed cron + IANA timezone yields a populated model. 
+ Locks the round-trip path users go through when creating a trigger.""" + params = ScheduleTriggerParams(cron="0 9 * * 1-5", timezone="Africa/Kigali") + + assert params.cron == "0 9 * * 1-5" + assert params.timezone == "Africa/Kigali" + + +def test_schedule_params_reject_malformed_cron_with_validation_error() -> None: + """``InvalidCronError`` from ``validate_cron`` must surface as + Pydantic ``ValidationError`` so the FastAPI layer returns 422 instead + of letting the bad value reach storage.""" + with pytest.raises(ValidationError): + ScheduleTriggerParams(cron="not cron", timezone="UTC") + + +def test_schedule_params_reject_unknown_timezone_with_validation_error() -> None: + """An unknown timezone is rejected at the API boundary — same gate + as the cron expression itself.""" + with pytest.raises(ValidationError): + ScheduleTriggerParams(cron="0 9 * * *", timezone="Mars/Olympus_Mons") diff --git a/surfsense_backend/tests/unit/event_bus/__init__.py b/surfsense_backend/tests/unit/event_bus/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/event_bus/conftest.py b/surfsense_backend/tests/unit/event_bus/conftest.py new file mode 100644 index 000000000..81ba4e464 --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/conftest.py @@ -0,0 +1,25 @@ +"""Shared fixtures for the ``app.event_bus`` unit-test tree. + +The event-type catalog is a module-level registry populated at import. Tests +that register their own event types (or assert on registry contents) snapshot +and restore it so state never leaks between tests. 
+""" + +from __future__ import annotations + +from collections.abc import Iterator + +import pytest + +from app.event_bus.catalog import catalog + + +@pytest.fixture +def isolated_event_catalog() -> Iterator[None]: + """Snapshot and restore the event-type catalog around a test.""" + snapshot = dict(catalog._registry) + try: + yield + finally: + catalog._registry.clear() + catalog._registry.update(snapshot) diff --git a/surfsense_backend/tests/unit/event_bus/test_bus.py b/surfsense_backend/tests/unit/event_bus/test_bus.py new file mode 100644 index 000000000..6c970760f --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/test_bus.py @@ -0,0 +1,181 @@ +"""``EventBus`` contract: subscribe, publish (stamp + fan out), dispatch. + +Each test uses a fresh ``EventBus`` — no shared global state. +""" + +from __future__ import annotations + +import pytest + +from app.event_bus import Event, EventBus + +pytestmark = pytest.mark.unit + + +def _event() -> Event: + return Event(event_type="x.happened", payload={"k": "v"}, search_space_id=1) + + +async def _noop(_event: Event) -> None: + return None + + +async def _other(_event: Event) -> None: + return None + + +# --- registry ------------------------------------------------------------- + + +def test_subscribe_then_subscribers_returns_the_handler() -> None: + bus = EventBus() + bus.subscribe(_noop) + + assert _noop in bus.subscribers() + + +def test_subscribe_is_idempotent_for_the_same_handler() -> None: + """Registering the same handler twice must not make it fire twice.""" + bus = EventBus() + bus.subscribe(_noop) + bus.subscribe(_noop) + + assert bus.subscribers().count(_noop) == 1 + + +def test_distinct_handlers_both_register() -> None: + bus = EventBus() + bus.subscribe(_noop) + bus.subscribe(_other) + + registered = bus.subscribers() + assert _noop in registered + assert _other in registered + + +def test_subscribers_returns_a_defensive_snapshot() -> None: + """Mutating the returned list must not corrupt the 
registry.""" + bus = EventBus() + bus.subscribe(_noop) + + snapshot = bus.subscribers() + snapshot.clear() + + assert _noop in bus.subscribers() + + +def test_subscribe_returns_handler_so_it_can_be_used_as_a_decorator() -> None: + bus = EventBus() + returned = bus.subscribe(_other) + + assert returned is _other + + +def test_two_buses_do_not_share_subscribers() -> None: + """The registry is per-instance, not global.""" + a = EventBus() + b = EventBus() + a.subscribe(_noop) + + assert _noop in a.subscribers() + assert _noop not in b.subscribers() + + +# --- dispatch ------------------------------------------------------------- + + +async def test_dispatch_delivers_event_to_every_subscriber() -> None: + bus = EventBus() + seen: list[tuple[str, Event]] = [] + + async def first(event: Event) -> None: + seen.append(("first", event)) + + async def second(event: Event) -> None: + seen.append(("second", event)) + + bus.subscribe(first) + bus.subscribe(second) + + event = _event() + await bus.dispatch(event) + + assert ("first", event) in seen + assert ("second", event) in seen + + +async def test_dispatch_isolates_a_failing_subscriber() -> None: + """A subscriber that raises must not stop a healthy one from running.""" + bus = EventBus() + healthy_ran = False + + async def boom(_event: Event) -> None: + raise RuntimeError("subscriber blew up") + + async def healthy(_event: Event) -> None: + nonlocal healthy_ran + healthy_ran = True + + bus.subscribe(boom) + bus.subscribe(healthy) + + await bus.dispatch(_event()) + + assert healthy_ran is True + + +async def test_dispatch_never_propagates_subscriber_errors() -> None: + """``dispatch`` itself must not raise even if every subscriber fails.""" + bus = EventBus() + + async def boom(_event: Event) -> None: + raise ValueError("nope") + + bus.subscribe(boom) + + await bus.dispatch(_event()) # must not raise + + +async def test_dispatch_with_no_subscribers_is_a_noop() -> None: + bus = EventBus() + await bus.dispatch(_event()) # 
must not raise + + +# --- publish -------------------------------------------------------------- + + +async def test_publish_builds_a_stamped_event_and_fans_it_out() -> None: + bus = EventBus() + received: list[Event] = [] + + async def handler(event: Event) -> None: + received.append(event) + + bus.subscribe(handler) + await bus.publish("document.indexed", {"document_id": 42}, search_space_id=7) + + assert len(received) == 1 + event = received[0] + assert event.event_type == "document.indexed" + assert event.payload == {"document_id": 42} + assert event.search_space_id == 7 + # Engine-stamped identity/time on the way through. + assert event.event_id + assert event.occurred_at + + +async def test_publish_defaults_payload_to_empty_dict() -> None: + bus = EventBus() + received: list[Event] = [] + + async def handler(event: Event) -> None: + received.append(event) + + bus.subscribe(handler) + await bus.publish("x.happened", search_space_id=1) + + assert received[0].payload == {} + + +async def test_publish_with_no_subscribers_is_a_noop() -> None: + await EventBus().publish("x.happened", search_space_id=1) # must not raise diff --git a/surfsense_backend/tests/unit/event_bus/test_catalog.py b/surfsense_backend/tests/unit/event_bus/test_catalog.py new file mode 100644 index 000000000..b09482bea --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/test_catalog.py @@ -0,0 +1,77 @@ +"""EventCatalog contract: register, look up, snapshot, derive schema.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from app.event_bus.catalog import EventCatalog, EventType + +pytestmark = pytest.mark.unit + + +class _SamplePayload(BaseModel): + document_id: int + + +def _event_type(type_: str = "test.thing") -> EventType: + return EventType( + type=type_, + description="A thing happened.", + payload_model=_SamplePayload, + ) + + +def test_register_then_get_returns_the_event_type(isolated_event_catalog: None) -> None: + from 
app.event_bus.catalog import catalog + + catalog.register(_event_type()) + + assert catalog.get("test.thing") is not None + assert catalog.get("test.thing").type == "test.thing" + + +def test_get_unknown_type_returns_none(isolated_event_catalog: None) -> None: + from app.event_bus.catalog import catalog + + assert catalog.get("does.not.exist") is None + + +def test_register_duplicate_type_raises(isolated_event_catalog: None) -> None: + """A type is a contract; registering it twice is a bug, not an override.""" + from app.event_bus.catalog import catalog + + catalog.register(_event_type()) + + with pytest.raises(ValueError, match="already registered"): + catalog.register(_event_type()) + + +def test_all_is_a_defensive_snapshot(isolated_event_catalog: None) -> None: + """Mutating the returned dict must not corrupt the registry.""" + from app.event_bus.catalog import catalog + + catalog.register(_event_type()) + + snapshot = catalog.all() + snapshot.clear() + + assert catalog.get("test.thing") is not None + + +def test_payload_schema_is_derived_from_the_payload_model() -> None: + """The JSON Schema a UI/validator consumes comes from the payload model.""" + event_type = _event_type() + + assert event_type.payload_schema == _SamplePayload.model_json_schema() + + +def test_each_catalog_instance_has_its_own_registry() -> None: + """Two EventCatalog instances are fully independent.""" + a = EventCatalog() + b = EventCatalog() + + a.register(_event_type()) + + assert a.get("test.thing") is not None + assert b.get("test.thing") is None diff --git a/surfsense_backend/tests/unit/event_bus/test_document_entered_folder.py b/surfsense_backend/tests/unit/event_bus/test_document_entered_folder.py new file mode 100644 index 000000000..6044b539e --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/test_document_entered_folder.py @@ -0,0 +1,56 @@ +"""``document.entered_folder`` payload contract + catalog registration.""" + +from __future__ import annotations + +import pytest + 
+from app.event_bus.catalog import catalog +from app.event_bus.events.document_entered_folder import ( + EVENT_TYPE, + DocumentEnteredFolderPayload, +) + +pytestmark = pytest.mark.unit + + +def _payload(**overrides: object) -> DocumentEnteredFolderPayload: + base: dict[str, object] = { + "document_id": 42, + "folder_id": 7, + "document_type": "FILE", + "title": "Q3 report.pdf", + } + base.update(overrides) + return DocumentEnteredFolderPayload(**base) + + +def test_payload_carries_the_filterable_fields() -> None: + payload = _payload(connector_id=12, created_by_id="abc") + + assert payload.document_id == 42 + assert payload.folder_id == 7 + assert payload.document_type == "FILE" + assert payload.connector_id == 12 + + +def test_first_placement_is_not_a_move() -> None: + """No previous folder (created or AI-sorted into place) → not a move.""" + assert _payload(previous_folder_id=None).is_move is False + + +def test_change_between_folders_is_a_move() -> None: + assert _payload(previous_folder_id=3).is_move is True + + +def test_is_move_is_serialized_for_filtering() -> None: + """Filters match against the dumped payload, so ``is_move`` must appear there.""" + dumped = _payload(previous_folder_id=3).model_dump() + + assert dumped["is_move"] is True + + +def test_event_type_is_registered_in_the_catalog() -> None: + registered = catalog.get(EVENT_TYPE) + + assert registered is not None + assert registered.payload_model is DocumentEnteredFolderPayload diff --git a/surfsense_backend/tests/unit/event_bus/test_entered_folder_predicate.py b/surfsense_backend/tests/unit/event_bus/test_entered_folder_predicate.py new file mode 100644 index 000000000..1f71e3abb --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/test_entered_folder_predicate.py @@ -0,0 +1,58 @@ +"""payload_if_entered_folder: decides whether a document commit warrants an event.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from 
app.event_bus.events.document_entered_folder import payload_if_entered_folder + +pytestmark = pytest.mark.unit + + +def _call(**overrides: Any) -> dict[str, Any] | None: + defaults: dict[str, Any] = { + "document_id": 1, + "search_space_id": 10, + "new_folder_id": 7, + "previous_folder_id": None, + "folder_id_changed": True, + "status_state": "ready", + "document_type": "FILE", + "title": "report.pdf", + "connector_id": None, + "created_by_id": None, + } + defaults.update(overrides) + return payload_if_entered_folder(**defaults) + + +def test_folder_set_ready_fires() -> None: + result = _call() + + assert result is not None + assert result["event_type"] == "document.entered_folder" + assert result["search_space_id"] == 10 + assert result["payload"]["folder_id"] == 7 + assert result["payload"]["previous_folder_id"] is None + + +def test_no_folder_is_silent() -> None: + assert _call(new_folder_id=None) is None + + +def test_not_ready_is_silent() -> None: + assert _call(status_state="processing") is None + + +def test_folder_unchanged_is_silent() -> None: + assert _call(folder_id_changed=False) is None + + +def test_move_carries_previous_folder_id() -> None: + result = _call(previous_folder_id=3, new_folder_id=7) + + assert result is not None + assert result["payload"]["previous_folder_id"] == 3 + assert result["payload"]["folder_id"] == 7 diff --git a/surfsense_backend/tests/unit/event_bus/test_event.py b/surfsense_backend/tests/unit/event_bus/test_event.py new file mode 100644 index 000000000..d09cb4364 --- /dev/null +++ b/surfsense_backend/tests/unit/event_bus/test_event.py @@ -0,0 +1,53 @@ +"""``Event`` contract: carry caller facts + engine-stamped id/time, round-trip JSON.""" + +from __future__ import annotations + +from datetime import datetime + +import pytest + +from app.event_bus.event import Event + +pytestmark = pytest.mark.unit + + +def test_event_carries_caller_supplied_facts() -> None: + """The three caller inputs are stored verbatim.""" + event = Event( 
+ event_type="document.indexed", + payload={"document_id": 42, "content_type": "pdf"}, + search_space_id=7, + ) + + assert event.event_type == "document.indexed" + assert event.payload == {"document_id": 42, "content_type": "pdf"} + assert event.search_space_id == 7 + + +def test_event_stamps_identity_and_time_when_not_supplied() -> None: + """Engine stamps id + time so subscribers can dedup/order.""" + event = Event(event_type="x.happened", payload={}, search_space_id=1) + + assert event.event_id + assert isinstance(event.occurred_at, datetime) + + +def test_event_ids_are_unique_per_instance() -> None: + """Two events published with identical content are still distinct facts.""" + first = Event(event_type="x.happened", payload={}, search_space_id=1) + second = Event(event_type="x.happened", payload={}, search_space_id=1) + + assert first.event_id != second.event_id + + +def test_event_survives_json_round_trip() -> None: + """Serialize → deserialize reproduces the event (subscribers queue it as JSON).""" + original = Event( + event_type="podcast.generated", + payload={"podcast_id": 9, "duration_s": 123.5}, + search_space_id=3, + ) + + restored = Event.model_validate_json(original.model_dump_json()) + + assert restored == original diff --git a/surfsense_backend/tests/unit/observability/test_helpers.py b/surfsense_backend/tests/unit/observability/test_helpers.py new file mode 100644 index 000000000..ae60c1939 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/test_helpers.py @@ -0,0 +1,101 @@ +"""Tests for pure observability helper functions.""" + +from __future__ import annotations + +import pytest + +from app.observability import metrics as ot_metrics, otel as ot + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _disable_otel(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + ot.reload_for_tests() + yield + 
ot.reload_for_tests() + + +@pytest.mark.parametrize( + ("task_name", "expected"), + [ + ("reindex_document", "reindex"), + ("delete_document_background", "delete"), + ("delete_folder_documents_background", "delete"), + ("delete_search_space_background", "delete"), + ("process_extension_document", "process"), + ("process_youtube_video", "process"), + ("process_file_upload", "process"), + ("process_file_upload_with_document", "process"), + ("process_circleback_meeting", "process"), + ("generate_video_presentation", "generate"), + ("generate_content_podcast", "generate"), + ("cleanup_stale_indexing_notifications", "cleanup"), + ("reconcile_pending_stripe_page_purchases", "reconcile"), + ("reconcile_pending_stripe_token_purchases", "reconcile"), + ("check_periodic_schedules", "check"), + ("ai_sort_search_space", "ai"), + ("index_notion_pages", "index"), + ("index_github_repos", "index"), + ("index_google_drive_files", "index"), + ("index_composio_connector", "index"), + ("index_obsidian_attachment", "index"), + ("index_local_folder", "index"), + ("index_uploaded_folder_files", "index"), + ("noseparator", "noseparator"), + ("", "unknown"), + ], +) +def test_parse_celery_task_label(task_name: str, expected: str) -> None: + assert ot_metrics.parse_celery_task_label(task_name) == expected + + +def test_parse_celery_task_label_handles_none() -> None: + assert ot_metrics.parse_celery_task_label(None) == "unknown" + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (type("RateLimitError", (Exception,), {})(), "rate_limited"), + (type("AuthenticationError", (Exception,), {})(), "auth_failed"), + (type("QuotaInsufficientError", (Exception,), {})(), "quota_exhausted"), + (TimeoutError(), "timeout"), + (type("APIConnectionError", (Exception,), {})(), "network_failed"), + (type("ServiceUnavailableError", (Exception,), {})(), "server_error"), + (type("LockContentionError", (Exception,), {})(), "lock_contention"), + (type("UnsupportedFormatError", (Exception,), {})(), 
"unsupported_format"), + (type("ProviderError", (Exception,), {})(), "provider_error"), + (RuntimeError("plain"), "unknown"), + ], +) +def test_categorize_exception(exc: BaseException, expected: str) -> None: + assert ot_metrics.categorize_exception(exc) == expected + + +def test_record_celery_queue_latency_noops_when_disabled() -> None: + ot_metrics.record_celery_queue_latency( + 0.5, + task_name="index_notion_pages", + queue="surfsense.connectors", + scheduled=False, + operation="index", + ) + + +def test_add_event_noops_when_disabled() -> None: + ot.add_event("test.event", {"value": 1}) + + +def test_add_event_noops_without_current_span(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeTrace: + @staticmethod + def get_current_span(): + return None + + monkeypatch.setattr(ot, "_ENABLED", True) + monkeypatch.setattr(ot, "_ot_trace", FakeTrace()) + + ot.add_event("test.event", {"value": 1}) diff --git a/surfsense_backend/tests/unit/observability/test_otel.py b/surfsense_backend/tests/unit/observability/test_otel.py index fc5813973..d3718e7b9 100644 --- a/surfsense_backend/tests/unit/observability/test_otel.py +++ b/surfsense_backend/tests/unit/observability/test_otel.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.observability import otel +from app.observability import bootstrap, metrics, otel pytestmark = pytest.mark.unit @@ -12,7 +12,14 @@ pytestmark = pytest.mark.unit @pytest.fixture(autouse=True) def _reset_otel_state(monkeypatch: pytest.MonkeyPatch): """Force a clean OTel disabled state per test, then restore after.""" - for env in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"): + for env in ( + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_PROTOCOL", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "SURFSENSE_DISABLE_OTEL", + "OTEL_SDK_DISABLED", + ): monkeypatch.delenv(env, raising=False) monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") otel.reload_for_tests() @@ -36,6 
+43,195 @@ def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None assert otel.reload_for_tests() is False +def test_spec_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_SDK_DISABLED", "true") + assert otel.reload_for_tests() is False + + +class TestBootstrapConfig: + def test_disabled_checks_both_kill_switches( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + assert bootstrap.is_otel_disabled() is False + + monkeypatch.setenv("OTEL_SDK_DISABLED", "on") + assert bootstrap.is_otel_disabled() is True + + def test_configured_by_shared_or_signal_endpoint( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + assert bootstrap.is_otel_configured() is False + + monkeypatch.setenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317" + ) + assert bootstrap.is_otel_configured() is True + + def test_init_otel_noops_when_disabled( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + called = {"traces": False} + + def fake_init_traces(app=None): + del app + called["traces"] = True + + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setattr(bootstrap, "init_traces", fake_init_traces) + + bootstrap.init_otel() + assert called["traces"] is False + + def test_init_otel_dispatches_enabled_signals( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + called: list[str] = [] + + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setattr( + bootstrap, "init_traces", lambda app=None: 
called.append("traces") + ) + monkeypatch.setattr(bootstrap, "init_metrics", lambda: called.append("metrics")) + monkeypatch.setattr(bootstrap, "init_logs", lambda: called.append("logs")) + + bootstrap.init_otel() + assert called == ["traces", "metrics", "logs"] + + def test_resource_defaults_include_service_metadata( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("OTEL_SERVICE_NAME", "custom-backend") + monkeypatch.setenv("SURFSENSE_ENV", "test") + + resource = bootstrap._build_resource() + attrs = dict(resource.attributes) + assert attrs["service.name"] == "custom-backend" + assert attrs["deployment.environment.name"] == "test" + assert attrs["deployment.environment"] == "test" + assert attrs["service.instance.id"] + + def test_deployment_environment_uses_surfsense_env_only( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_ENV", raising=False) + + assert bootstrap._deployment_environment() == "dev" + + monkeypatch.setenv("SURFSENSE_ENV", "production") + + assert bootstrap._deployment_environment() == "production" + + def test_shutdown_is_safe_without_providers(self) -> None: + bootstrap.shutdown_otel() + + def test_init_logs_enables_log_correlation( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + calls: list[dict[str, object]] = [] + + class FakeLoggingInstrumentor: + def instrument(self, **kwargs: object) -> None: + calls.append(kwargs) + + def fake_safe_instrument(name: str, callback): + assert name == "logging" + monkeypatch.setattr( + "opentelemetry.instrumentation.logging.LoggingInstrumentor", + FakeLoggingInstrumentor, + ) + callback() + return True + + monkeypatch.setattr(bootstrap, "_LOGS_INITIALIZED", False) + monkeypatch.setattr(bootstrap, "_safe_instrument", fake_safe_instrument) + + bootstrap.init_logs() + + assert calls == [{"set_logging_format": True}] + + +class TestMetricHelpers: + def test_all_metric_helpers_noop_safely_when_disabled(self) -> None: + 
metrics.record_model_call_duration(12.5, model="gpt-4o", provider="openai") + metrics.record_model_token_usage( + input_tokens=10, + output_tokens=5, + model="gpt-4o", + provider="openai", + ) + metrics.record_tool_call_duration(3.0, tool_name="web_search") + metrics.record_tool_call_error(tool_name="web_search") + metrics.record_kb_search_duration( + 4.0, + search_space_id=1, + surface="documents", + ) + metrics.record_compaction_run(reason="auto") + metrics.record_permission_ask(permission="write_file") + metrics.record_interrupt(interrupt_type="permission_ask") + metrics.record_indexing_document_duration(1.2, document_type="FILE") + metrics.record_indexing_document_outcome(document_type="FILE", status="success") + metrics.record_connector_sync_duration( + 2.3, + connector_type="index_notion_pages", + ) + metrics.record_connector_sync_outcome( + connector_type="index_notion_pages", + status="success", + ) + metrics.record_auth_failure(reason="UNAUTHORIZED") + metrics.record_rate_limit_rejection(scope="login") + metrics.record_perf_elapsed(7.0, label="[test]") + + def test_runtime_observables_register_once( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + class FakeMeter: + def __init__(self) -> None: + self.names: list[str] = [] + + def create_observable_gauge(self, name: str, **kwargs) -> None: + del kwargs + self.names.append(name) + + fake_meter = FakeMeter() + monkeypatch.setattr(metrics, "_OBSERVABLES_REGISTERED", False) + monkeypatch.setattr(metrics, "_is_enabled", lambda: True) + monkeypatch.setattr(metrics, "_get_meter", lambda: fake_meter) + + metrics.register_runtime_observables() + metrics.register_runtime_observables() + + assert len(fake_meter.names) == 6 + assert fake_meter.names.count("python.asyncio.tasks") == 1 + monkeypatch.setattr(metrics, "_OBSERVABLES_REGISTERED", False) + + +def test_log_record_factory_provides_zero_otel_fields() -> None: + import logging + + import main # noqa: F401 + + record = logging.getLogRecordFactory()( + 
"test", + logging.INFO, + __file__, + 1, + "hello", + (), + None, + ) + assert record.otelTraceID == "0" + assert record.otelSpanID == "0" + + class TestNoopSpansWhenDisabled: def test_generic_span_yields_noop(self) -> None: with otel.span("any.thing", attributes={"x": 1}) as sp: diff --git a/surfsense_backend/tests/unit/observability/test_retriever_otel.py b/surfsense_backend/tests/unit/observability/test_retriever_otel.py new file mode 100644 index 000000000..9712a3150 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/test_retriever_otel.py @@ -0,0 +1,61 @@ +"""Tests for retriever OTel wrappers.""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any + +import pytest + +from app.retriever.documents_hybrid_search import _instrument_search + +pytestmark = pytest.mark.unit + + +class _Span: + def __init__(self) -> None: + self.attrs: dict[str, Any] = {} + + def set_attribute(self, key: str, value: Any) -> None: + self.attrs[key] = value + + +@contextmanager +def _fake_span(**kwargs): + span = _Span() + span.attrs.update(kwargs) + yield span + + +@pytest.mark.asyncio +async def test_retriever_wrapper_records_one_span_and_metric(monkeypatch) -> None: + calls: list[dict[str, Any]] = [] + + monkeypatch.setattr( + "app.retriever.documents_hybrid_search.ot.kb_search_span", + lambda **kwargs: _fake_span(**kwargs), + ) + monkeypatch.setattr( + "app.retriever.documents_hybrid_search.ot_metrics.record_kb_search_duration", + lambda duration_ms, **attrs: calls.append( + {"duration_ms": duration_ms, **attrs} + ), + ) + + class Retriever: + @_instrument_search("hybrid") + async def search( + self, + query_text: str, + top_k: int, + search_space_id: int, + ) -> list[str]: + del query_text, top_k, search_space_id + return ["doc-1", "doc-2"] + + result = await Retriever().search("hello", 3, 42) + + assert result == ["doc-1", "doc-2"] + assert len(calls) == 1 + assert calls[0]["search_space_id"] == 42 + assert 
calls[0]["surface"] == "documents" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py new file mode 100644 index 000000000..e014bb911 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_parallel_refactor_parity.py @@ -0,0 +1,557 @@ +"""Parity gate for the parallel refactor of ``stream_new_chat.py``. + +The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with +the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over +atomically. This file pins externally-observable behaviour at module +boundaries so a divergence between the two trees fails loudly *before* the +cutover. + +What we verify: + + 1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from + the new tree have the same call signature as the originals. + 2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the + same outputs as the inline code in the legacy file for representative + inputs (initial thinking step, image-capability gate, runtime context, + SSE frame sequences, token-usage frame shape, persistence guards). + 3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` / + ``can_recover_provider_rate_limit`` exist and are addressable. + +Delete this file along with ``stream_new_chat.py`` once the cutover is done +(see the parent refactor plan). 
+""" + +from __future__ import annotations + +import asyncio +import inspect +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from app.agents.new_chat.context import SurfSenseContextSchema +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.stream_new_chat import ( + stream_new_chat as old_stream_new_chat, + stream_resume_chat as old_stream_resume_chat, +) +from app.tasks.chat.streaming.flows import ( + stream_new_chat as new_stream_new_chat, + stream_resume_chat as new_stream_resume_chat, +) +from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import ( + build_initial_thinking_step, +) +from app.tasks.chat.streaming.flows.new_chat.llm_capability import ( + check_image_input_capability, +) +from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import ( + await_persist_task, + spawn_persist_assistant_shell_task, + spawn_persist_user_task, + spawn_set_ai_responding_bg, +) +from app.tasks.chat.streaming.flows.new_chat.runtime_context import ( + build_new_chat_runtime_context, +) +from app.tasks.chat.streaming.flows.resume_chat.runtime_context import ( + build_resume_chat_runtime_context, +) +from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame +from app.tasks.chat.streaming.flows.shared.first_frames import ( + iter_final_frames, + iter_initial_frames, +) +from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle +from app.tasks.chat.streaming.flows.shared.premium_quota import ( + PremiumReservation, + needs_premium_quota, +) +from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( + can_recover_provider_rate_limit, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------- signature + + +def _normalize_annotation(ann: Any) -> str: + """Compare-friendly form for an annotation. 
+ + The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import + annotations``, so its annotations are evaluated at import time and come + back as type objects / typing generics. The new tree DOES use it, so its + annotations are PEP-563 strings. + + Both reprs describe the same types — strip the module prefixes / typing + namespace + the ``<class '...'>`` wrapper so we compare the canonical + declared form. + """ + if ann is inspect.Signature.empty: + return "" + raw = ann if isinstance(ann, str) else repr(ann) + cleaned = ( + raw.replace("typing.", "") + .replace("collections.abc.", "") + .replace("app.db.", "") + .replace("app.agents.new_chat.filesystem_selection.", "") + .replace("app.agents.new_chat.context.", "") + ) + # Unwrap ``<class 'int'>`` → ``int`` (legacy-side type objects). + if cleaned.startswith("<class '") and cleaned.endswith("'>"): + cleaned = cleaned[len("<class '") : -len("'>")] + return cleaned + + +def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]: + return [ + (p.name, p.default, _normalize_annotation(p.annotation)) + for p in sig.parameters.values() + ] + + +def test_stream_new_chat_signature_matches_legacy() -> None: + old = inspect.signature(old_stream_new_chat) + new = inspect.signature(new_stream_new_chat) + assert _normalize_sig(new) == _normalize_sig(old) + assert _normalize_annotation(new.return_annotation) == _normalize_annotation( + old.return_annotation + ) + + +def test_stream_resume_chat_signature_matches_legacy() -> None: + old = inspect.signature(old_stream_resume_chat) + new = inspect.signature(new_stream_resume_chat) + assert _normalize_sig(new) == _normalize_sig(old) + assert _normalize_annotation(new.return_annotation) == _normalize_annotation( + old.return_annotation + ) + + +def test_orchestrators_are_async_generator_functions() -> None: + assert inspect.isasyncgenfunction(new_stream_new_chat) + assert inspect.isasyncgenfunction(new_stream_resume_chat) + + +# ------------------------------------------------------------ initial thinking + + 
+@pytest.mark.parametrize( + "user_query, image_urls, expected_title, expected_action", + [ + ("hello world", None, "Understanding your request", "Processing"), + ( + "", + ["data:image/png;base64,AAA"], + "Understanding your request", + "Processing", + ), + ("", None, "Understanding your request", "Processing"), + ], +) +def test_initial_thinking_step_branches( + user_query: str, + image_urls: list[str] | None, + expected_title: str, + expected_action: str, +) -> None: + step = build_initial_thinking_step( + user_query=user_query, + user_image_data_urls=image_urls, + ) + assert step.step_id == "thinking-1" + assert step.title == expected_title + assert len(step.items) == 1 + assert step.items[0].startswith(f"{expected_action}: ") + + +def test_initial_thinking_step_truncates_long_query() -> None: + long_query = "x" * 200 + step = build_initial_thinking_step( + user_query=long_query, + user_image_data_urls=None, + ) + # 80-char truncation + ellipsis, sandwiched after "Processing: ". + assert "..." 
in step.items[0] + item = step.items[0] + payload = item[len("Processing: ") :] + assert payload.startswith("x" * 80) and payload.endswith("...") + + +# ------------------------------------------------------------ capability gate + + +def test_image_capability_passes_without_images() -> None: + assert ( + check_image_input_capability(user_image_data_urls=None, agent_config=None) + is None + ) + + +def test_image_capability_passes_when_capability_unknown() -> None: + """Unknown / unmapped models are not blocked — only models LiteLLM has + *explicitly* marked text-only trip the gate.""" + + class _AgentConfig: + provider = "openrouter" + model_name = "unknown-mystery-model" + custom_provider = None + config_name = "Unknown" + litellm_params: dict[str, Any] = {} + + with patch( + "app.services.provider_capabilities.is_known_text_only_chat_model", + return_value=False, + ): + assert ( + check_image_input_capability( + user_image_data_urls=["data:image/png;base64,AAA"], + agent_config=_AgentConfig(), # type: ignore[arg-type] + ) + is None + ) + + +def test_image_capability_blocks_known_text_only_models() -> None: + class _AgentConfig: + provider = "openai" + model_name = "gpt-3.5-turbo" + custom_provider = None + config_name = "GPT-3.5" + litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"} + + with patch( + "app.services.provider_capabilities.is_known_text_only_chat_model", + return_value=True, + ): + result = check_image_input_capability( + user_image_data_urls=["data:image/png;base64,AAA"], + agent_config=_AgentConfig(), # type: ignore[arg-type] + ) + assert result is not None + message, error_code = result + assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + assert "GPT-3.5" in message + + +# ---------------------------------------------------------------- runtime ctx + + +def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None: + ctx = build_new_chat_runtime_context( + search_space_id=7, + mentioned_document_ids=[1, 2], + 
accepted_folder_ids=[10], + mentioned_folder_ids=[20, 30], + request_id="req", + turn_id="t1", + ) + assert isinstance(ctx, SurfSenseContextSchema) + assert ctx.search_space_id == 7 + assert list(ctx.mentioned_document_ids) == [1, 2] + assert list(ctx.mentioned_folder_ids) == [10] + assert ctx.request_id == "req" + assert ctx.turn_id == "t1" + + +def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None: + ctx = build_new_chat_runtime_context( + search_space_id=7, + mentioned_document_ids=None, + accepted_folder_ids=[], + mentioned_folder_ids=[20, 30], + request_id=None, + turn_id="t2", + ) + assert list(ctx.mentioned_folder_ids) == [20, 30] + + +def test_resume_chat_runtime_context_empty_mention_lists() -> None: + ctx = build_resume_chat_runtime_context( + search_space_id=42, request_id="req-r", turn_id="t-r" + ) + assert ctx.search_space_id == 42 + assert ctx.request_id == "req-r" + assert ctx.turn_id == "t-r" + + +# ---------------------------------------------------------------- SSE frames + + +def test_iter_initial_frames_emits_canonical_sequence() -> None: + svc = VercelStreamingService() + frames = list(iter_initial_frames(svc, turn_id="42:1700000000000")) + # Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy). 
+ assert len(frames) == 4 + assert "42:1700000000000" in frames[2] + assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3] + + +def test_iter_final_frames_emits_idle_then_finish_done() -> None: + svc = VercelStreamingService() + frames = list(iter_final_frames(svc)) + assert len(frames) == 4 + assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0] + + +# ----------------------------------------------------------- token usage frame + + +class _FakeAccumulator: + """Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads.""" + + def __init__(self, summary: Any = None) -> None: + self._summary = summary + self.calls = [1, 2, 3] + self.grand_total = 100 + self.total_cost_micros = 50_000 + self.total_prompt_tokens = 60 + self.total_completion_tokens = 40 + + def per_message_summary(self) -> Any: + return self._summary + + def serialized_calls(self) -> list[Any]: + return list(self.calls) + + +def test_token_usage_frame_skipped_when_no_summary() -> None: + svc = VercelStreamingService() + frames = list( + iter_token_usage_frame( + svc, + accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type] + log_label="parity-empty", + ) + ) + assert frames == [] + + +def test_token_usage_frame_emitted_when_summary_present() -> None: + svc = VercelStreamingService() + frames = list( + iter_token_usage_frame( + svc, + accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type] + log_label="parity-populated", + ) + ) + assert len(frames) == 1 + # Field shape on the wire is fixed by the FE; assert each surfaces. 
+ payload = frames[0] + for key in ( + '"prompt_tokens":60', + '"completion_tokens":40', + '"total_tokens":100', + '"cost_micros":50000', + ): + assert key in payload.replace(" ", "") + + +# ------------------------------------------------------------------ llm_bundle + + +def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None: + async def _run() -> tuple[Any, Any, str | None]: + with ( + patch( + "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id", + return_value=None, + ), + ): + return await load_llm_bundle( + session=AsyncMock(), # type: ignore[arg-type] + config_id=-1, + search_space_id=7, + ) + + llm, agent_config, error = asyncio.run(_run()) + assert llm is None + assert agent_config is None + assert error is not None and "id -1" in error + + +def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None: + async def _run() -> tuple[Any, Any, str | None]: + with ( + patch( + "app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config", + new=AsyncMock(return_value=None), + ), + ): + return await load_llm_bundle( + session=AsyncMock(), # type: ignore[arg-type] + config_id=12, + search_space_id=7, + ) + + llm, agent_config, error = asyncio.run(_run()) + assert llm is None + assert agent_config is None + assert error is not None and "id 12" in error + + +# ----------------------------------------------------------------- premium quota + + +def test_needs_premium_quota_requires_user_and_premium_flag() -> None: + class _AgentConfig: + is_premium = True + + class _NonPremium: + is_premium = False + + assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type] + assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type] + assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type] + assert needs_premium_quota(None, "user-1") is False + + +def test_premium_reservation_dataclass_shape() -> None: + # Sanity: the dataclass exists and 
carries the fields the orchestrator uses. + r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True) + assert r.request_id == "abc" + assert r.reserved_micros == 100 + assert r.allowed is True + + +# ----------------------------------------------------------- rate-limit guard + + +@pytest.mark.parametrize( + "first_event_seen, recovered, requested_id, current_id, expected", + [ + (False, False, 0, -1, True), + # Already recovered: no second pass. + (False, True, 0, -1, False), + # User explicitly picked a config: don't silently switch. + (False, False, 5, -1, False), + # Already on a database-backed (positive) id. + (False, False, 0, 7, False), + # User has already seen output: silent rebuild not possible. + (True, False, 0, -1, False), + ], +) +def test_can_recover_provider_rate_limit_truth_table( + first_event_seen: bool, + recovered: bool, + requested_id: int, + current_id: int, + expected: bool, +) -> None: + # Use a known rate-limit-shaped exception so the helper's last condition + # is satisfied; the guard only short-circuits to False when one of the + # *other* preconditions fails. 
+ exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}') + assert ( + can_recover_provider_rate_limit( + exc, + first_event_seen=first_event_seen, + runtime_rate_limit_recovered=recovered, + requested_llm_config_id=requested_id, + current_llm_config_id=current_id, + ) + is expected + ) + + +def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None: + assert ( + can_recover_provider_rate_limit( + ValueError("not a rate limit"), + first_event_seen=False, + runtime_rate_limit_recovered=False, + requested_llm_config_id=0, + current_llm_config_id=-1, + ) + is False + ) + + +# --------------------------------------------------------- persistence spawn + + +def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None: + async def _run() -> set[asyncio.Task]: + background: set[asyncio.Task] = set() + spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background) + return background + + bg = asyncio.run(_run()) + assert bg == set() + + +def test_spawn_persist_user_task_registers_and_self_unregisters() -> None: + async def _run() -> tuple[int, int]: + background: set[asyncio.Task] = set() + with patch( + "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn", + new=AsyncMock(return_value=99), + ): + task = spawn_persist_user_task( + chat_id=1, + user_id="u", + turn_id="t", + user_query="hi", + user_image_data_urls=None, + mentioned_documents=None, + background_tasks=background, + ) + size_before_await = len(background) + result = await asyncio.shield(task) + # Give the done-callback one event-loop tick to run. 
+ await asyncio.sleep(0) + return size_before_await, result # type: ignore[return-value] + + size_before, result = asyncio.run(_run()) + assert size_before == 1 + assert result == 99 + + +def test_spawn_persist_assistant_shell_task_registers() -> None: + async def _run() -> int | None: + background: set[asyncio.Task] = set() + with patch( + "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell", + new=AsyncMock(return_value=42), + ): + task = spawn_persist_assistant_shell_task( + chat_id=1, + user_id="u", + turn_id="t", + background_tasks=background, + ) + return await asyncio.shield(task) + + assert asyncio.run(_run()) == 42 + + +def test_await_persist_task_returns_none_on_failure() -> None: + async def _run() -> int | None: + async def _boom() -> int: + raise RuntimeError("DB down") + + task = asyncio.create_task(_boom()) + return await await_persist_task( + task, + chat_id=1, + turn_id="t", + log_label="parity-failure", + ) + + assert asyncio.run(_run()) is None + + +def test_await_persist_task_returns_none_for_none_input() -> None: + async def _run() -> int | None: + return await await_persist_task( + None, + chat_id=1, + turn_id="t", + log_label="parity-none", + ) + + assert asyncio.run(_run()) is None diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 953aebbef..eae54b1d4 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -313,6 +313,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149 }, ] +[[package]] +name = "asgiref" +version = "3.11.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/63/40/f03da1264ae8f7cfdbf9146542e5e7e8100a4c66ab48e791df9a03d3f6c0/asgiref-3.11.1.tar.gz", hash = 
"sha256:5f184dc43b7e763efe848065441eac62229c9f7b0475f41f80e207a114eda4ce", size = 38550 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/0a/a72d10ed65068e115044937873362e6e32fab1b7dce0046aeb224682c989/asgiref-3.11.1-py3-none-any.whl", hash = "sha256:e8667a091e69529631969fd45dc268fa79b99c92c5fcdda727757e52146ec133", size = 24345 }, +] + [[package]] name = "asyncpg" version = "0.31.0" @@ -1256,6 +1265,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/ca/6a667ccbe649856dcd3458bab80b016681b274399d6211187c6ab969fc50/courlan-1.3.2-py3-none-any.whl", hash = "sha256:d0dab52cf5b5b1000ee2839fbc2837e93b2514d3cb5bb61ae158a55b7a04c6be", size = 33848 }, ] +[[package]] +name = "croniter" +version = "6.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/de/5832661ed55107b8a09af3f0a2e71e0957226a59eb1dcf0a445cce6daf20/croniter-6.2.2.tar.gz", hash = "sha256:ba60832a5ec8e12e51b8691c3309a113d1cf6526bdf1a48150ce8ec7a532d0ab", size = 113762 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/39/783980e78cb92c2d7bdb1fc7dbc86e94ccc6d58224d76a7f1f51b6c51e30/croniter-6.2.2-py3-none-any.whl", hash = "sha256:a5d17b1060974d36251ea4faf388233eca8acf0d09cbd92d35f4c4ac8f279960", size = 45422 }, +] + [[package]] name = "cryptography" version = "46.0.6" @@ -5184,6 +5205,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676 }, ] +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d0/37/b6708e0eff5c5fb9aba2e0ea09f7f3bcbfd12a592d2a780241b5f6014df7/opentelemetry_exporter_otlp-1.40.0.tar.gz", hash = "sha256:7caa0870b95e2fcb59d64e16e2b639ecffb07771b6cd0000b5d12e5e4fef765a", size = 6152 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/fc/aea77c28d9f3ffef2fdafdc3f4a235aee4091d262ddabd25882f47ce5c5f/opentelemetry_exporter_otlp-1.40.0-py3-none-any.whl", hash = "sha256:48c87e539ec9afb30dc443775a1334cc5487de2f72a770a4c00b1610bf6c697d", size = 7023 }, +] + [[package]] name = "opentelemetry-exporter-otlp-proto-common" version = "1.40.0" @@ -5263,6 +5297,141 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/f3/1edc42716521a3f754ac32ffb908f102e0f131f8e43fcd9ab29cab286723/opentelemetry_instrumentation_aiohttp_client-0.61b0-py3-none-any.whl", hash = "sha256:09bc47514c162507b357366ce15578743fd6305078cf7d872db1c99c13fa6972", size = 14534 }, ] +[[package]] +name = "opentelemetry-instrumentation-asgi" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asgiref" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/3e/143cf5c034e58037307e6a24f06e0dd64b2c49ae60a965fc580027581931/opentelemetry_instrumentation_asgi-0.61b0.tar.gz", hash = "sha256:9d08e127244361dc33976d39dd4ca8f128b5aa5a7ae425208400a80a095019b5", size = 26691 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/78/154470cf9d741a7487fbb5067357b87386475bbb77948a6707cae982e158/opentelemetry_instrumentation_asgi-0.61b0-py3-none-any.whl", hash = "sha256:e4b3ce6b66074e525e717efff20745434e5efd5d9df6557710856fba356da7a4", size = 16980 }, +] + +[[package]] +name = "opentelemetry-instrumentation-celery" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + 
{ name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8d/43/e79108a804d16b1dc8ff28edd0e94ac393cf6359a5adcd7cdd2ec4be85f4/opentelemetry_instrumentation_celery-0.61b0.tar.gz", hash = "sha256:0e352a567dc89ed8bc083fc635035ce3c5b96bbbd92831ffd676e93b87f8e94f", size = 14780 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/ed/c05f3c84b455654eb6c047474ffde61ed92efc24030f64213c98bca9d44b/opentelemetry_instrumentation_celery-0.61b0-py3-none-any.whl", hash = "sha256:01235733ff0cdf571cb03b270645abb14b9c8d830313dc5842097ec90146320b", size = 13856 }, +] + +[[package]] +name = "opentelemetry-instrumentation-dbapi" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/ed/ba91c9e4a3ec65781e9c59982109f0a36de9fa574f622596b33d1985dab5/opentelemetry_instrumentation_dbapi-0.61b0.tar.gz", hash = "sha256:02fa800682c1de87dcad0e59f2092b3b6fb8b8ea0636518f989e1166b418dcb9", size = 16761 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/a5/d26c68f3fd33eb7410985cef7700bb426e2c4a26de9207902cbbffb19a3f/opentelemetry_instrumentation_dbapi-0.61b0-py3-none-any.whl", hash = "sha256:8f762c39c8edd20c6aef3282550a2cfbfec76c3f431bf5c36327dcf9ece2e5a0", size = 14134 }, +] + +[[package]] +name = "opentelemetry-instrumentation-fastapi" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-instrumentation-asgi" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/37/35/aa727bb6e6ef930dcdc96a617b83748fece57b43c47d83ba8d83fbeca657/opentelemetry_instrumentation_fastapi-0.61b0.tar.gz", hash = "sha256:3a24f35b07c557ae1bbc483bf8412221f25d79a405f8b047de8b670722e2fa9f", size = 24800 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/05/acfeb2cccd434242a0a7d0ea29afaf077e04b42b35b485d89aee4e0d9340/opentelemetry_instrumentation_fastapi-0.61b0-py3-none-any.whl", hash = "sha256:a1a844d846540d687d377516b2ff698b51d87c781b59f47c214359c4a241047c", size = 13485 }, +] + +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/2a/e2becd55e33c29d1d9ef76e2579040ed1951cb33bacba259f6aff2fdd2a6/opentelemetry_instrumentation_httpx-0.61b0.tar.gz", hash = "sha256:6569ec097946c5551c2a4252f74c98666addd1bf047c1dde6b4ef426719ff8dd", size = 24104 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/88/dde310dce56e2d85cf1a09507f5888544955309edc4b8d22971d6d3d1417/opentelemetry_instrumentation_httpx-0.61b0-py3-none-any.whl", hash = "sha256:dee05c93a6593a5dc3ae5d9d5c01df8b4e2c5d02e49275e5558534ee46343d5e", size = 17198 }, +] + +[[package]] +name = "opentelemetry-instrumentation-logging" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/e0/69473f925acfe2d4edf5c23bcced36906ac3627aa7c5722a8e3f60825f3b/opentelemetry_instrumentation_logging-0.61b0.tar.gz", hash = "sha256:feaa30b700acd2a37cc81db5f562ab0c3a5b6cc2453595e98b72c01dcf649584", size = 17906 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e0/0e/2137db5239cc5e564495549a4d11488a7af9b48fc76520a0eea20e69ddae/opentelemetry_instrumentation_logging-0.61b0-py3-none-any.whl", hash = "sha256:6d87e5ded6a0128d775d41511f8380910a1b610671081d16efb05ac3711c0074", size = 17076 }, +] + +[[package]] +name = "opentelemetry-instrumentation-psycopg" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-instrumentation-dbapi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/00/b98148b3054eb8301a56d523de82ee2fd86a047dba38330c2404d85496e3/opentelemetry_instrumentation_psycopg-0.61b0.tar.gz", hash = "sha256:74e9fed3802945f7ae335cffc30fd18cf58c34a4d0619315f799fa21eb5c74ff", size = 11907 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/2b/3c36bfc6dc82a7c50c769aff407eaf32e688d655bc61a774609d96b55603/opentelemetry_instrumentation_psycopg-0.61b0-py3-none-any.whl", hash = "sha256:a3e242cad56c0ad4f4f872017c73ce7e6c7012081dda6bd0d776c127fedc358a", size = 11662 }, +] + +[[package]] +name = "opentelemetry-instrumentation-redis" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/21/26205f89358a5f2be3ee5512d3d3bce16b622977f64aeaa9d3fa8887dd39/opentelemetry_instrumentation_redis-0.61b0.tar.gz", hash = "sha256:ae0fbb56be9a641e621d55b02a7d62977a2c77c5ee760addd79b9b266e46e523", size = 14781 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/e1/8f4c8e4194291dbe828aeabe779050a8497b379ad90040a5a0a7074b1d08/opentelemetry_instrumentation_redis-0.61b0-py3-none-any.whl", hash = "sha256:8d4e850bbb5f8eeafa44c0eac3a007990c7125de187bc9c3659e29ff7e091172", size = 15506 
}, +] + +[[package]] +name = "opentelemetry-instrumentation-sqlalchemy" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "packaging" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/4f/3a325b180944610697a0a926d49d782b41a86120050d44fefb2715b630ac/opentelemetry_instrumentation_sqlalchemy-0.61b0.tar.gz", hash = "sha256:13a3a159a2043a52f0180b3757fbaa26741b0e08abb50deddce4394c118956e6", size = 15343 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/97/b906a930c6a1a20c53ecc8b58cabc2cdd0ce560a2b5d44259084ffe4333e/opentelemetry_instrumentation_sqlalchemy-0.61b0-py3-none-any.whl", hash = "sha256:f115e0be54116ba4c327b8d7b68db4045ee18d44439d888ab8130a549c50d1c1", size = 14547 }, +] + [[package]] name = "opentelemetry-proto" version = "1.40.0" @@ -7947,7 +8116,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.25" +version = "0.0.26" source = { editable = "." 
} dependencies = [ { name = "alembic" }, @@ -7958,6 +8127,7 @@ dependencies = [ { name = "celery", extra = ["redis"] }, { name = "chonkie", extra = ["all"] }, { name = "composio" }, + { name = "croniter" }, { name = "datasets" }, { name = "daytona" }, { name = "deepagents" }, @@ -7993,6 +8163,17 @@ dependencies = [ { name = "notion-client" }, { name = "notion-markdown" }, { name = "numpy" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp" }, + { name = "opentelemetry-instrumentation-celery" }, + { name = "opentelemetry-instrumentation-fastapi" }, + { name = "opentelemetry-instrumentation-httpx" }, + { name = "opentelemetry-instrumentation-logging" }, + { name = "opentelemetry-instrumentation-psycopg" }, + { name = "opentelemetry-instrumentation-redis" }, + { name = "opentelemetry-instrumentation-sqlalchemy" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, { name = "pgvector" }, { name = "playwright" }, { name = "psycopg", extra = ["binary", "pool"] }, @@ -8043,6 +8224,7 @@ requires-dist = [ { name = "celery", extras = ["redis"], specifier = ">=5.5.3" }, { name = "chonkie", extras = ["all"], specifier = ">=1.5.0" }, { name = "composio", specifier = ">=0.10.9" }, + { name = "croniter", specifier = ">=2.0.0" }, { name = "datasets", specifier = ">=2.21.0" }, { name = "daytona", specifier = ">=0.146.0" }, { name = "deepagents", specifier = ">=0.4.12,<0.5" }, @@ -8078,6 +8260,17 @@ requires-dist = [ { name = "notion-client", specifier = ">=2.3.0" }, { name = "notion-markdown", specifier = ">=0.7.0" }, { name = "numpy", specifier = ">=1.24.0" }, + { name = "opentelemetry-api", specifier = ">=1.40.0" }, + { name = "opentelemetry-exporter-otlp", specifier = ">=1.40.0" }, + { name = "opentelemetry-instrumentation-celery", specifier = ">=0.61b0" }, + { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.61b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.61b0" }, + { name = 
"opentelemetry-instrumentation-logging", specifier = ">=0.61b0" }, + { name = "opentelemetry-instrumentation-psycopg", specifier = ">=0.61b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = ">=0.61b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.61b0" }, + { name = "opentelemetry-sdk", specifier = ">=1.40.0" }, + { name = "opentelemetry-semantic-conventions", specifier = ">=0.61b0" }, { name = "pgvector", specifier = ">=0.3.6" }, { name = "playwright", specifier = ">=1.50.0" }, { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.2" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 2f17899a8..13cd31b80 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.25", + "version": "0.0.26", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/.env.example b/surfsense_desktop/.env.example index e127b99e0..f4e797250 100644 --- a/surfsense_desktop/.env.example +++ b/surfsense_desktop/.env.example @@ -3,7 +3,12 @@ # The hosted web frontend URL. Used to intercept OAuth redirects and keep them # inside the desktop app. Set to your production frontend domain. -HOSTED_FRONTEND_URL=https://surfsense.net +HOSTED_FRONTEND_URL=https://surfsense.com + +# Runtime override for the above (read at app start, no rebuild required). +# Useful for self-hosters whose backend NEXT_FRONTEND_URL differs from the +# value baked into the official desktop builds. Leave empty to use HOSTED_FRONTEND_URL. 
+# SURFSENSE_HOSTED_FRONTEND_URL_OVERRIDE= # PostHog analytics (leave empty to disable) POSTHOG_KEY= diff --git a/surfsense_desktop/assets/icons/1024x1024.png b/surfsense_desktop/assets/icons/1024x1024.png new file mode 100644 index 000000000..853201c5e Binary files /dev/null and b/surfsense_desktop/assets/icons/1024x1024.png differ diff --git a/surfsense_desktop/assets/icons/128x128.png b/surfsense_desktop/assets/icons/128x128.png new file mode 100644 index 000000000..97286c8b6 Binary files /dev/null and b/surfsense_desktop/assets/icons/128x128.png differ diff --git a/surfsense_desktop/assets/icons/16x16.png b/surfsense_desktop/assets/icons/16x16.png new file mode 100644 index 000000000..860f9fef1 Binary files /dev/null and b/surfsense_desktop/assets/icons/16x16.png differ diff --git a/surfsense_desktop/assets/icons/256x256.png b/surfsense_desktop/assets/icons/256x256.png new file mode 100644 index 000000000..edb7aa512 Binary files /dev/null and b/surfsense_desktop/assets/icons/256x256.png differ diff --git a/surfsense_desktop/assets/icons/32x32.png b/surfsense_desktop/assets/icons/32x32.png new file mode 100644 index 000000000..2c1ef1222 Binary files /dev/null and b/surfsense_desktop/assets/icons/32x32.png differ diff --git a/surfsense_desktop/assets/icons/48x48.png b/surfsense_desktop/assets/icons/48x48.png new file mode 100644 index 000000000..2d765024d Binary files /dev/null and b/surfsense_desktop/assets/icons/48x48.png differ diff --git a/surfsense_desktop/assets/icons/512x512.png b/surfsense_desktop/assets/icons/512x512.png new file mode 100644 index 000000000..3fc480dd7 Binary files /dev/null and b/surfsense_desktop/assets/icons/512x512.png differ diff --git a/surfsense_desktop/assets/icons/64x64.png b/surfsense_desktop/assets/icons/64x64.png new file mode 100644 index 000000000..a218a4ee2 Binary files /dev/null and b/surfsense_desktop/assets/icons/64x64.png differ diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml 
index e4e7670ec..0a7c48203 100644 --- a/surfsense_desktop/electron-builder.yml +++ b/surfsense_desktop/electron-builder.yml @@ -55,6 +55,11 @@ mac: NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists." NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer." NSAppleEventsUsageDescription: "SurfSense uses Apple Events to interact with the active application." + # `surfsense://` scheme — install-time registration for LaunchServices. + CFBundleURLTypes: + - CFBundleURLName: com.surfsense.desktop + CFBundleURLSchemes: + - surfsense target: - target: dmg arch: [x64, arm64] @@ -72,7 +77,7 @@ nsis: createDesktopShortcut: true createStartMenuShortcut: true linux: - icon: assets/icon.png + icon: assets/icons/ category: Utility artifactName: "${productName}-${version}-${arch}.${ext}" mimeTypes: diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index 0ad279ece..1f0e6dafc 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,6 +1,7 @@ { "name": "surfsense-desktop", - "version": "0.0.25", + "productName": "SurfSense", + "version": "0.0.26", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_desktop/src/ipc/channels.ts b/surfsense_desktop/src/ipc/channels.ts index 8d2af5107..17daab9a6 100644 --- a/surfsense_desktop/src/ipc/channels.ts +++ b/surfsense_desktop/src/ipc/channels.ts @@ -2,6 +2,8 @@ export const IPC_CHANNELS = { OPEN_EXTERNAL: 'open-external', GET_APP_VERSION: 'get-app-version', DEEP_LINK: 'deep-link', + UPDATE_DOWNLOADED: 'update:downloaded', + UPDATE_INSTALL_NOW: 'update:install-now', QUICK_ASK_TEXT: 'quick-ask-text', SET_QUICK_ASK_MODE: 'set-quick-ask-mode', GET_QUICK_ASK_MODE: 'get-quick-ask-mode', diff --git 
a/surfsense_desktop/src/ipc/handlers.ts b/surfsense_desktop/src/ipc/handlers.ts index d918fd90d..ed7eaac66 100644 --- a/surfsense_desktop/src/ipc/handlers.ts +++ b/surfsense_desktop/src/ipc/handlers.ts @@ -51,6 +51,7 @@ import { stopAgentFilesystemTreeWatch, type AgentFilesystemTreeWatchOptions, } from '../modules/agent-filesystem-tree-watcher'; +import { installDownloadedUpdate } from '../modules/auto-updater'; let authTokens: { bearer: string; refresh: string } | null = null; @@ -70,6 +71,10 @@ export function registerIpcHandlers(): void { return app.getVersion(); }); + ipcMain.handle(IPC_CHANNELS.UPDATE_INSTALL_NOW, () => { + installDownloadedUpdate(); + }); + ipcMain.handle(IPC_CHANNELS.GET_PERMISSIONS_STATUS, () => { return getPermissionsStatus(); }); diff --git a/surfsense_desktop/src/main.ts b/surfsense_desktop/src/main.ts index 492c61f17..632758ba8 100644 --- a/surfsense_desktop/src/main.ts +++ b/surfsense_desktop/src/main.ts @@ -1,7 +1,7 @@ import { app } from 'electron'; import { registerGlobalErrorHandlers, showErrorDialog } from './modules/errors'; -import { startNextServer } from './modules/server'; +import { startNextServer, stopNextServer } from './modules/server'; import { createMainWindow, getMainWindow, markQuitting } from './modules/window'; import { setupDeepLinks, handlePendingDeepLink, hasPendingDeepLink } from './modules/deep-links'; import { setupAutoUpdater } from './modules/auto-updater'; @@ -19,6 +19,7 @@ import { } from './modules/auto-launch'; registerGlobalErrorHandlers(); +app.setName('SurfSense'); if (!setupDeepLinks()) { app.quit(); @@ -93,6 +94,7 @@ app.on('will-quit', async (e) => { e.preventDefault(); unregisterQuickAsk(); unregisterFolderWatcher(); + stopNextServer(); destroyTray(); await shutdownAnalytics(); app.exit(); diff --git a/surfsense_desktop/src/modules/auto-updater.ts b/surfsense_desktop/src/modules/auto-updater.ts index e323abe53..b318b737d 100644 --- a/surfsense_desktop/src/modules/auto-updater.ts +++ 
b/surfsense_desktop/src/modules/auto-updater.ts @@ -1,57 +1,201 @@ -import { app, dialog } from 'electron'; +import { app, BrowserWindow, dialog } from 'electron'; +import { IPC_CHANNELS } from '../ipc/channels'; import { trackEvent } from './analytics'; const SEMVER_RE = /^\d+\.\d+\.\d+/; -export function setupAutoUpdater(): void { - if (!app.isPackaged) return; +type AutoUpdater = { + autoDownload: boolean; + on(event: string, listener: (...args: any[]) => void): void; + once(event: string, listener: (...args: any[]) => void): void; + removeListener(event: string, listener: (...args: any[]) => void): void; + checkForUpdates(): Promise; + quitAndInstall(): void; +}; - const version = app.getVersion(); - if (!SEMVER_RE.test(version)) { - console.log(`Auto-updater: skipping — "${version}" is not valid semver`); - return; +type UpdateInfo = { + version: string; +}; + +type UpdateMenuState = + | { status: 'idle' } + | { status: 'downloading'; version: string } + | { status: 'ready'; version: string }; + +let listenersRegistered = false; +let updateMenuState: UpdateMenuState = { status: 'idle' }; +const updateMenuStateListeners = new Set<(state: UpdateMenuState) => void>(); + +export function getUpdateMenuState(): UpdateMenuState { + return updateMenuState; +} + +export function onUpdateMenuStateChange(listener: (state: UpdateMenuState) => void): () => void { + updateMenuStateListeners.add(listener); + return () => { + updateMenuStateListeners.delete(listener); + }; +} + +function setUpdateMenuState(state: UpdateMenuState): void { + updateMenuState = state; + for (const listener of updateMenuStateListeners) { + listener(state); } +} +function getAutoUpdater(): AutoUpdater { const { autoUpdater } = require('electron-updater'); + return autoUpdater as AutoUpdater; +} +function configureAutoUpdater(autoUpdater: AutoUpdater): void { autoUpdater.autoDownload = true; - autoUpdater.on('update-available', (info: { version: string }) => { + if (listenersRegistered) return; + 
listenersRegistered = true; + + const version = app.getVersion(); + + autoUpdater.on('update-available', (info: UpdateInfo) => { console.log(`Update available: ${info.version}`); + setUpdateMenuState({ status: 'downloading', version: info.version }); trackEvent('desktop_update_available', { current_version: version, new_version: info.version, }); }); - autoUpdater.on('update-downloaded', (info: { version: string }) => { + autoUpdater.on('update-downloaded', (info: UpdateInfo) => { console.log(`Update downloaded: ${info.version}`); + setUpdateMenuState({ status: 'ready', version: info.version }); trackEvent('desktop_update_downloaded', { current_version: version, new_version: info.version, }); - dialog.showMessageBox({ - type: 'info', - buttons: ['Restart', 'Later'], - defaultId: 0, - title: 'Update Ready', - message: `Version ${info.version} has been downloaded. Restart to apply the update.`, - }).then(({ response }: { response: number }) => { - if (response === 0) { - trackEvent('desktop_update_install_accepted', { new_version: info.version }); - autoUpdater.quitAndInstall(); - } else { - trackEvent('desktop_update_install_deferred', { new_version: info.version }); - } - }); + notifyRenderersUpdateDownloaded(info); + }); + + autoUpdater.on('update-not-available', () => { + setUpdateMenuState({ status: 'idle' }); }); autoUpdater.on('error', (err: Error) => { + setUpdateMenuState({ status: 'idle' }); console.log('Auto-updater: update check skipped —', err.message?.split('\n')[0]); trackEvent('desktop_update_error', { message: err.message?.split('\n')[0], }); }); +} + +function notifyRenderersUpdateDownloaded(info: UpdateInfo): void { + for (const win of BrowserWindow.getAllWindows()) { + if (!win.isDestroyed()) { + win.webContents.send(IPC_CHANNELS.UPDATE_DOWNLOADED, { + version: info.version, + }); + } + } +} + +export function installDownloadedUpdate(): void { + const autoUpdater = getAutoUpdater(); + trackEvent('desktop_update_install_accepted', { source: 
'renderer_prompt' }); + autoUpdater.quitAndInstall(); +} + +export function setupAutoUpdater(): void { + if (!app.isPackaged) return; + + const version = app.getVersion(); + if (!SEMVER_RE.test(version)) { + console.log(`Auto-updater: skipping - "${version}" is not valid semver`); + return; + } + + const autoUpdater = getAutoUpdater(); + configureAutoUpdater(autoUpdater); autoUpdater.checkForUpdates().catch(() => {}); } + +export async function checkForUpdatesManually(): Promise { + const currentState = getUpdateMenuState(); + if (currentState.status === 'ready') { + installDownloadedUpdate(); + return; + } + if (currentState.status === 'downloading') return; + + if (!app.isPackaged) { + await dialog.showMessageBox({ + type: 'info', + title: 'Updates Unavailable', + message: 'Updates are only available in packaged builds.', + }); + return; + } + + const version = app.getVersion(); + if (!SEMVER_RE.test(version)) { + await dialog.showMessageBox({ + type: 'info', + title: 'Updates Unavailable', + message: `Version "${version}" is not a valid release version, so updates cannot be checked.`, + }); + return; + } + + const autoUpdater = getAutoUpdater(); + configureAutoUpdater(autoUpdater); + + try { + const result = await new Promise<'not-available' | 'downloaded'>((resolve, reject) => { + const cleanup = () => { + autoUpdater.removeListener('update-available', onAvailable); + autoUpdater.removeListener('update-not-available', onNotAvailable); + autoUpdater.removeListener('update-downloaded', onDownloaded); + autoUpdater.removeListener('error', onError); + }; + const onAvailable = () => {}; + const onNotAvailable = () => { + cleanup(); + resolve('not-available'); + }; + const onDownloaded = () => { + cleanup(); + resolve('downloaded'); + }; + const onError = (err: Error) => { + cleanup(); + reject(err); + }; + + autoUpdater.once('update-available', onAvailable); + autoUpdater.once('update-not-available', onNotAvailable); + autoUpdater.once('update-downloaded', 
onDownloaded); + autoUpdater.once('error', onError); + autoUpdater.checkForUpdates().catch((err: Error) => { + cleanup(); + setUpdateMenuState({ status: 'idle' }); + reject(err); + }); + }); + + if (result === 'not-available') { + await dialog.showMessageBox({ + type: 'info', + title: 'No Updates Available', + message: "You're up to date.", + }); + } + } catch (err) { + setUpdateMenuState({ status: 'idle' }); + await dialog.showMessageBox({ + type: 'error', + title: 'Update Check Failed', + message: err instanceof Error ? err.message : String(err), + }); + } +} diff --git a/surfsense_desktop/src/modules/deep-links.ts b/surfsense_desktop/src/modules/deep-links.ts index 11b7bfcff..d4c0da467 100644 --- a/surfsense_desktop/src/modules/deep-links.ts +++ b/surfsense_desktop/src/modules/deep-links.ts @@ -1,7 +1,7 @@ import { app } from 'electron'; import path from 'path'; import { getMainWindow } from './window'; -import { getServerPort } from './server'; +import { getServerOrigin } from './server'; import { trackEvent } from './analytics'; const PROTOCOL = 'surfsense'; @@ -23,7 +23,7 @@ function handleDeepLink(url: string) { }); if (parsed.hostname === 'auth' && parsed.pathname === '/callback') { const params = parsed.searchParams.toString(); - win.loadURL(`http://localhost:${getServerPort()}/auth/callback?${params}`); + win.loadURL(`${getServerOrigin()}/auth/callback?${params}`); } win.show(); @@ -60,6 +60,11 @@ export function setupDeepLinks(): boolean { app.setAsDefaultProtocolClient(PROTOCOL); } + // Cold-start on Windows/Linux: protocol URL arrives via argv of the + // first instance, not via `second-instance` or `open-url`. 
+ const cold = process.argv.find((arg) => arg.startsWith(`${PROTOCOL}://`)); + if (cold) handleDeepLink(cold); + return true; } diff --git a/surfsense_desktop/src/modules/menu.ts b/surfsense_desktop/src/modules/menu.ts index 128a73a21..629d88a04 100644 --- a/surfsense_desktop/src/modules/menu.ts +++ b/surfsense_desktop/src/modules/menu.ts @@ -1,13 +1,118 @@ -import { Menu } from 'electron'; +import { app, Menu, shell } from 'electron'; +import { + checkForUpdatesManually, + getUpdateMenuState, + installDownloadedUpdate, + onUpdateMenuStateChange, +} from './auto-updater'; + +let updateMenuListenerRegistered = false; + +function getUpdateMenuItem(): Electron.MenuItemConstructorOptions { + const state = getUpdateMenuState(); + + if (state.status === 'downloading') { + return { + label: 'Downloading...', + enabled: false, + }; + } + + if (state.status === 'ready') { + return { + label: 'Install and Restart', + click: () => { + installDownloadedUpdate(); + }, + }; + } + + return { + label: 'Check for Updates...', + click: () => { + void checkForUpdatesManually(); + }, + }; +} + +const privacyPolicyItem: Electron.MenuItemConstructorOptions = { + label: 'Privacy Policy', + click: () => { + void shell.openExternal('https://www.surfsense.com/privacy'); + }, +}; + +const termsOfServiceItem: Electron.MenuItemConstructorOptions = { + label: 'Terms of Service', + click: () => { + void shell.openExternal('https://www.surfsense.com/terms'); + }, +}; export function setupMenu(): void { + if (!updateMenuListenerRegistered) { + updateMenuListenerRegistered = true; + onUpdateMenuStateChange(() => { + setupMenu(); + }); + } + const isMac = process.platform === 'darwin'; + const isDev = !app.isPackaged; + const updateMenuItem = getUpdateMenuItem(); + const viewSubmenu: Electron.MenuItemConstructorOptions[] = [ + { role: 'reload' as const }, + { role: 'forceReload' as const }, + ...(isDev + ? 
[ + { role: 'toggleDevTools' as const }, + ] + : []), + { type: 'separator' as const }, + { role: 'resetZoom' as const }, + { role: 'zoomIn' as const }, + { role: 'zoomOut' as const }, + { type: 'separator' as const }, + { role: 'togglefullscreen' as const }, + ]; const template: Electron.MenuItemConstructorOptions[] = [ - ...(isMac ? [{ role: 'appMenu' as const }] : []), + ...(isMac + ? [{ + label: app.name, + submenu: [ + { role: 'about' as const }, + updateMenuItem, + { type: 'separator' as const }, + { role: 'services' as const }, + { type: 'separator' as const }, + { role: 'hide' as const }, + { role: 'hideOthers' as const }, + { role: 'unhide' as const }, + { type: 'separator' as const }, + { role: 'quit' as const }, + ], + }] + : []), { role: 'fileMenu' as const }, { role: 'editMenu' as const }, - { role: 'viewMenu' as const }, + { + label: 'View', + submenu: viewSubmenu, + }, { role: 'windowMenu' as const }, + { + role: 'help' as const, + submenu: [ + ...(!isMac + ? [ + updateMenuItem, + { type: 'separator' as const }, + ] + : []), + privacyPolicyItem, + termsOfServiceItem, + ], + }, ]; Menu.setApplicationMenu(Menu.buildFromTemplate(template)); } diff --git a/surfsense_desktop/src/modules/quick-ask.ts b/surfsense_desktop/src/modules/quick-ask.ts index b31ae1bcd..0807e2e08 100644 --- a/surfsense_desktop/src/modules/quick-ask.ts +++ b/surfsense_desktop/src/modules/quick-ask.ts @@ -1,8 +1,8 @@ -import { BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron'; +import { app, BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron'; import path from 'path'; import { IPC_CHANNELS } from '../ipc/channels'; import { checkAccessibilityPermission, getFrontmostApp, simulateCopy, simulatePaste } from './platform'; -import { getServerPort } from './server'; +import { getServerOrigin } from './server'; import { getShortcuts } from './shortcuts'; import { getActiveSearchSpaceId } from './active-search-space'; import { 
trackEvent } from './analytics'; @@ -51,6 +51,7 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow { contextIsolation: true, nodeIntegration: false, sandbox: true, + devTools: !app.isPackaged, }, show: false, skipTaskbar: true, @@ -58,7 +59,7 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow { const spaceId = pendingSearchSpaceId; const route = spaceId ? `/dashboard/${spaceId}/new-chat` : '/dashboard'; - quickAskWindow.loadURL(`http://localhost:${getServerPort()}${route}?quickAssist=true`); + quickAskWindow.loadURL(`${getServerOrigin()}${route}?quickAssist=true`); quickAskWindow.once('ready-to-show', () => { quickAskWindow?.show(); @@ -69,7 +70,7 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow { }); quickAskWindow.webContents.setWindowOpenHandler(({ url }) => { - if (url.startsWith('http://localhost')) { + if (url.startsWith(getServerOrigin())) { return { action: 'allow' }; } shell.openExternal(url); diff --git a/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts b/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts index fd771b0f7..0cfc92297 100644 --- a/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts +++ b/surfsense_desktop/src/modules/screen-capture/screen-region-picker.ts @@ -1,4 +1,4 @@ -import { BrowserWindow, desktopCapturer, nativeImage, screen } from 'electron'; +import { app, BrowserWindow, desktopCapturer, nativeImage, screen } from 'electron'; import path from 'path'; import { IPC_CHANNELS } from '../../ipc/channels'; function fitNativeImageToWorkArea(img: Electron.NativeImage, display: Electron.Display): Electron.NativeImage { @@ -261,6 +261,7 @@ export function pickScreenRegion(opts?: { windowDataUrl?: string }): Promise { contextIsolation: true, nodeIntegration: false, sandbox: true, + devTools: !app.isPackaged, }, }); diff --git a/surfsense_desktop/src/modules/server.ts b/surfsense_desktop/src/modules/server.ts index 
e2f078a8c..fc2fa05c3 100644 --- a/surfsense_desktop/src/modules/server.ts +++ b/surfsense_desktop/src/modules/server.ts @@ -1,14 +1,20 @@ import path from 'path'; -import { app } from 'electron'; +import { app, utilityProcess } from 'electron'; import { getPort } from 'get-port-please'; const isDev = !app.isPackaged; +const SERVER_HOST = '127.0.0.1'; let serverPort = 3000; +let nextServerProcess: ReturnType | null = null; export function getServerPort(): number { return serverPort; } +export function getServerOrigin(): string { + return `http://${SERVER_HOST}:${serverPort}`; +} + function getStandalonePath(): string { if (isDev) { return path.join(__dirname, '..', '..', 'surfsense_web', '.next', 'standalone', 'surfsense_web'); @@ -38,16 +44,55 @@ export async function startNextServer(): Promise { const standalonePath = getStandalonePath(); const serverScript = path.join(standalonePath, 'server.js'); - process.env.PORT = String(serverPort); - process.env.HOSTNAME = '0.0.0.0'; - process.env.NODE_ENV = 'production'; - process.chdir(standalonePath); + const child = utilityProcess.fork(serverScript, [], { + cwd: standalonePath, + env: { + ...process.env, + PORT: String(serverPort), + // Loopback bind: avoids 0.0.0.0 leaking into request.url and redirect origins. 
+ HOSTNAME: SERVER_HOST, + NODE_ENV: 'production', + }, + serviceName: 'SurfSense Next Server', + stdio: 'pipe', + }); + nextServerProcess = child; - require(serverScript); + child.stdout?.on('data', (chunk) => { + process.stdout.write(chunk); + }); + child.stderr?.on('data', (chunk) => { + process.stderr.write(chunk); + }); - const ready = await waitForServer(`http://localhost:${serverPort}`); + const handleExit = (code: number) => { + if (nextServerProcess === child) { + nextServerProcess = null; + } + console.error(`Next.js server exited with code ${code}`); + }; + child.on('exit', handleExit); + + let startupExitHandler: ((code: number) => void) | null = null; + const exited = new Promise((_resolve, reject) => { + startupExitHandler = (code: number) => { + reject(new Error(`Next.js server exited before startup completed with code ${code}`)); + }; + child.once('exit', startupExitHandler); + }); + + const ready = await Promise.race([waitForServer(getServerOrigin()), exited]); + if (startupExitHandler) { + child.removeListener('exit', startupExitHandler); + } if (!ready) { + stopNextServer(); throw new Error('Next.js server failed to start within 30 s'); } console.log(`Next.js server ready on port ${serverPort}`); } + +export function stopNextServer(): void { + nextServerProcess?.kill(); + nextServerProcess = null; +} diff --git a/surfsense_desktop/src/modules/tray.ts b/surfsense_desktop/src/modules/tray.ts index f0221fe53..e71168f6e 100644 --- a/surfsense_desktop/src/modules/tray.ts +++ b/surfsense_desktop/src/modules/tray.ts @@ -10,6 +10,30 @@ let tray: Tray | null = null; let registeredGeneralAssist: string | null = null; let registeredScreenshotAssist: string | null = null; +function buildContextMenu(screenshotAccelerator: string): Menu { + return Menu.buildFromTemplate([ + { label: 'Open SurfSense', click: () => showMainWindow('tray_menu') }, + { + label: 'Take Screenshot\u2026', + accelerator: screenshotAccelerator || undefined, + click: () => { + 
trackEvent('desktop_tray_screenshot_clicked'); + void Promise.resolve(runScreenshotAssistShortcut()).catch((err) => { + console.error('[tray] Screenshot Assist failed:', err); + }); + }, + }, + { type: 'separator' }, + { + label: 'Quit', + click: () => { + trackEvent('desktop_tray_quit_clicked'); + app.exit(0); + }, + }, + ]); +} + function getTrayIcon(): NativeImage { const iconName = process.platform === 'darwin' @@ -59,22 +83,10 @@ export async function createTray(): Promise { tray = new Tray(getTrayIcon()); tray.setToolTip('SurfSense'); - const contextMenu = Menu.buildFromTemplate([ - { label: 'Open SurfSense', click: () => showMainWindow('tray_menu') }, - { type: 'separator' }, - { - label: 'Quit', - click: () => { - trackEvent('desktop_tray_quit_clicked'); - app.exit(0); - }, - }, - ]); - - tray.setContextMenu(contextMenu); + const shortcuts = await getShortcuts(); + tray.setContextMenu(buildContextMenu(shortcuts.screenshotAssist)); tray.on('double-click', () => showMainWindow('tray_click')); - const shortcuts = await getShortcuts(); registeredGeneralAssist = registerOne( null, shortcuts.generalAssist, @@ -107,6 +119,7 @@ export async function reregisterScreenshotAssist(): Promise { runScreenshotAssistShortcut, 'Screenshot Assist' ); + tray?.setContextMenu(buildContextMenu(shortcuts.screenshotAssist)); } export function destroyTray(): void { diff --git a/surfsense_desktop/src/modules/window.ts b/surfsense_desktop/src/modules/window.ts index 5317005d5..42011d089 100644 --- a/surfsense_desktop/src/modules/window.ts +++ b/surfsense_desktop/src/modules/window.ts @@ -2,12 +2,30 @@ import { app, BrowserWindow, shell, session } from 'electron'; import path from 'path'; import { trackEvent } from './analytics'; import { showErrorDialog } from './errors'; -import { getServerPort } from './server'; +import { getServerOrigin, getServerPort } from './server'; import { setActiveSearchSpaceId } from './active-search-space'; const isDev = !app.isPackaged; -const 
HOSTED_FRONTEND_URL = process.env.HOSTED_FRONTEND_URL as string; const isMac = process.platform === 'darwin'; +const WINDOW_TITLE = 'SurfSense'; + +function getHostedFrontendUrl(): string { + return ( + process.env.SURFSENSE_HOSTED_FRONTEND_URL_OVERRIDE || + process.env.HOSTED_FRONTEND_URL || + 'https://surfsense.com' + ); +} + +function getHostedFrontendHosts(): string[] { + try { + const host = new URL(getHostedFrontendUrl()).host; + const sibling = host.startsWith('www.') ? host.slice(4) : `www.${host}`; + return Array.from(new Set([host, sibling])); + } catch { + return []; + } +} let mainWindow: BrowserWindow | null = null; let isQuitting = false; @@ -24,6 +42,7 @@ export function markQuitting(): void { export function createMainWindow(initialPath = '/dashboard'): BrowserWindow { mainWindow = new BrowserWindow({ + title: WINDOW_TITLE, width: 1280, height: 800, minWidth: 800, @@ -34,6 +53,7 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow { nodeIntegration: false, sandbox: true, webviewTag: false, + devTools: !app.isPackaged, }, show: false, ...(isMac @@ -48,21 +68,66 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow { mainWindow?.show(); }); - mainWindow.loadURL(`http://localhost:${getServerPort()}${initialPath}`); + mainWindow.webContents.on('page-title-updated', (event) => { + event.preventDefault(); + mainWindow?.setTitle(WINDOW_TITLE); + }); + mainWindow.webContents.on('did-finish-load', () => { + mainWindow?.setTitle(WINDOW_TITLE); + }); + + mainWindow.loadURL(`${getServerOrigin()}${initialPath}`); mainWindow.webContents.setWindowOpenHandler(({ url }) => { - if (url.startsWith('http://localhost')) { + if (url.startsWith(getServerOrigin())) { return { action: 'allow' }; } shell.openExternal(url); return { action: 'deny' }; }); - const filter = { urls: [`${HOSTED_FRONTEND_URL}/*`] }; - session.defaultSession.webRequest.onBeforeRequest(filter, (details, callback) => { - const rewritten = 
details.url.replace(HOSTED_FRONTEND_URL, `http://localhost:${getServerPort()}`); - callback({ redirectURL: rewritten }); - }); + const hostedHosts = getHostedFrontendHosts(); + const rewriteFilter = { + urls: hostedHosts.flatMap((h) => [`http://${h}/*`, `https://${h}/*`]), + }; + if (rewriteFilter.urls.length > 0) { + session.defaultSession.webRequest.onBeforeRequest(rewriteFilter, (details, callback) => { + try { + const u = new URL(details.url); + const originalHost = u.host; + const local = new URL(getServerOrigin()); + u.protocol = local.protocol; + u.host = local.host; + trackEvent('desktop_oauth_redirect_intercepted', { + host: originalHost, + path: u.pathname, + rewritten_to_port: getServerPort(), + }); + callback({ redirectURL: u.toString() }); + } catch { + callback({}); + } + }); + } + + // Diagnostic: connector callback landing somewhere other than localhost + // means the rewrite missed and the user is stranded off-app. + session.defaultSession.webRequest.onCompleted( + { urls: ['*://*/dashboard/*/connectors/callback*'] }, + (details) => { + try { + const u = new URL(details.url); + if (u.hostname === 'localhost' || u.hostname === '127.0.0.1') return; + trackEvent('desktop_oauth_redirect_missed', { + host: u.host, + path: u.pathname, + status_code: details.statusCode, + }); + } catch { + // ignore malformed URLs + } + } + ); mainWindow.webContents.on('did-fail-load', (_event, errorCode, errorDescription, validatedURL) => { console.error(`Failed to load ${validatedURL}: ${errorDescription} (${errorCode})`); diff --git a/surfsense_desktop/src/preload.ts b/surfsense_desktop/src/preload.ts index 7d72e9da5..97232179c 100644 --- a/surfsense_desktop/src/preload.ts +++ b/surfsense_desktop/src/preload.ts @@ -10,6 +10,14 @@ contextBridge.exposeInMainWorld('electronAPI', { }, openExternal: (url: string) => ipcRenderer.send(IPC_CHANNELS.OPEN_EXTERNAL, url), getAppVersion: () => ipcRenderer.invoke(IPC_CHANNELS.GET_APP_VERSION), + onUpdateDownloaded: (callback: 
(data: { version: string }) => void) => { + const listener = (_event: unknown, data: { version: string }) => callback(data); + ipcRenderer.on(IPC_CHANNELS.UPDATE_DOWNLOADED, listener); + return () => { + ipcRenderer.removeListener(IPC_CHANNELS.UPDATE_DOWNLOADED, listener); + }; + }, + installUpdateNow: () => ipcRenderer.invoke(IPC_CHANNELS.UPDATE_INSTALL_NOW), onDeepLink: (callback: (url: string) => void) => { const listener = (_event: unknown, url: string) => callback(url); ipcRenderer.on(IPC_CHANNELS.DEEP_LINK, listener); diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index 4512f3396..5cea9b6d2 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -221,10 +221,7 @@ export default async function FreeHubPage() { {/* In-content ad: above the model table */} -