mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
Merge branch 'feat/migrate-electric-to-zero' into electon-desktop
This commit is contained in:
commit
a1119c401f
396 changed files with 39475 additions and 18064 deletions
136
.cursor/skills/system-architecture/SKILL.md
Executable file
136
.cursor/skills/system-architecture/SKILL.md
Executable file
|
|
@ -0,0 +1,136 @@
|
|||
---
|
||||
name: system-architecture
|
||||
description: Design systems with appropriate complexity - no more, no less. Use when the user asks to architect applications, design system boundaries, plan service decomposition, evaluate monolith vs microservices, make scaling decisions, or review structural trade-offs. Applies to new system design, refactoring, and migration planning.
|
||||
---
|
||||
|
||||
# System Architecture
|
||||
|
||||
Design real structures with clear boundaries, explicit trade-offs, and appropriate complexity. Match architecture to actual requirements, not imagined future needs.
|
||||
|
||||
## Workflow
|
||||
|
||||
When the user requests an architecture, follow these steps:
|
||||
|
||||
```
|
||||
Task Progress:
|
||||
- [ ] Step 1: Clarify constraints
|
||||
- [ ] Step 2: Identify domains
|
||||
- [ ] Step 3: Map data flow
|
||||
- [ ] Step 4: Draw boundaries with rationale
|
||||
- [ ] Step 5: Run complexity checklist
|
||||
- [ ] Step 6: Present architecture with trade-offs
|
||||
```
|
||||
|
||||
**Step 1 - Clarify constraints.** Ask about:
|
||||
|
||||
| Constraint | Question | Why it matters |
|
||||
|------------|----------|----------------|
|
||||
| Scale | What's the real load? (users, requests/sec, data size) | Design for 10x current, not 1000x |
|
||||
| Team | How many developers? How many teams? | Deployable units ≤ number of teams |
|
||||
| Lifespan | Prototype? MVP? Long-term product? | Temporary systems need temporary solutions |
|
||||
| Change vectors | What actually varies? | Abstract only where you have evidence of variation |
|
||||
|
||||
**Step 2 - Identify domains.** Group by business capability, not technical layer. Look for things that change for different reasons and at different rates.
|
||||
|
||||
**Step 3 - Map data flow.** Trace: where does data enter → how does it transform → where does it exit? Make the flow obvious.
|
||||
|
||||
**Step 4 - Draw boundaries.** Every boundary needs a reason: different team, different change rate, different compliance requirement, or different scaling need.
|
||||
|
||||
**Step 5 - Run complexity checklist.** Before adding any non-trivial pattern:
|
||||
|
||||
```
|
||||
[ ] Have I tried the simple solution?
|
||||
[ ] Do I have evidence it's insufficient?
|
||||
[ ] Can my team operate this?
|
||||
[ ] Will this still make sense in 6 months?
|
||||
[ ] Can I explain why this complexity is necessary?
|
||||
```
|
||||
|
||||
If any answer is "no", keep it simple.
|
||||
|
||||
**Step 6 - Present the architecture** using the output template below.
|
||||
|
||||
## Output Template
|
||||
|
||||
```markdown
|
||||
### System: [Name]
|
||||
|
||||
**Constraints**:
|
||||
- Scale: [current and expected load]
|
||||
- Team: [size and structure]
|
||||
- Lifespan: [prototype / MVP / long-term]
|
||||
|
||||
**Architecture**:
|
||||
[Component diagram or description of components and their relationships]
|
||||
|
||||
**Data Flow**:
|
||||
[How data enters → transforms → exits]
|
||||
|
||||
**Key Boundaries**:
|
||||
| Boundary | Reason | Change Rate |
|
||||
|----------|--------|-------------|
|
||||
| ... | ... | ... |
|
||||
|
||||
**Trade-offs**:
|
||||
- Chose X over Y because [reason]
|
||||
- Accepted [limitation] to gain [benefit]
|
||||
|
||||
**Complexity Justification**:
|
||||
- [Each non-trivial pattern] → [why it's needed, with evidence]
|
||||
```
|
||||
|
||||
## Core Principles
|
||||
|
||||
1. **Boundaries at real differences.** Separate concerns that change for different reasons and at different rates.
|
||||
2. **Dependencies flow inward.** Core logic depends on nothing. Infrastructure depends on core.
|
||||
3. **Follow the data.** Architecture should make data flow obvious.
|
||||
4. **Design for failure.** Network fails. Databases timeout. Build compensation into the structure.
|
||||
5. **Design for operations.** You will debug this at 3am. Every request needs a trace. Every error needs context for replay.
|
||||
|
||||
For concrete good/bad examples of each principle, see [examples.md](examples.md).
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
| Don't | Do Instead |
|
||||
|-------|------------|
|
||||
| Microservices for a 3-person team | Well-structured monolith |
|
||||
| Event sourcing for CRUD | Simple state storage |
|
||||
| Message queues within the same process | Just call the function |
|
||||
| Distributed transactions | Redesign to avoid, or accept eventual consistency |
|
||||
| Repository wrapping an ORM | Use the ORM directly |
|
||||
| Interfaces with one implementation | Mock at boundaries only |
|
||||
| AbstractFactoryFactoryBean | Just instantiate the thing |
|
||||
| DI containers for simple graphs | Constructor injection is enough |
|
||||
| Clean Architecture for a TODO app | Match layers to actual complexity |
|
||||
| DDD tactics without strategic design | Aggregates need bounded contexts |
|
||||
| Hexagonal ports with one adapter | Just call the database |
|
||||
| CQRS when reads = writes | Add when they diverge |
|
||||
| "We might swap databases" | You won't; rewrite if you do |
|
||||
| "Multi-tenant someday" | Build it when you have tenant #2 |
|
||||
| "Microservices for team scale" | Helps at 50+ engineers, not 4 |
|
||||
|
||||
## Success Criteria
|
||||
|
||||
Your architecture is right-sized when:
|
||||
|
||||
1. **You can draw it** - dependency graph fits on a whiteboard
|
||||
2. **You can explain it** - new team member understands data flow in 30 minutes
|
||||
3. **You can change it** - adding a feature touches 1-3 modules, not 10
|
||||
4. **You can delete it** - removing a component needs no archaeology
|
||||
5. **You can debug it** - tracing a request takes minutes, not hours
|
||||
6. **It matches your team** - deployable units ≤ number of teams
|
||||
|
||||
## When the Simple Solution Isn't Enough
|
||||
|
||||
If the complexity checklist says "yes, scale is real", see [scaling-checklist.md](scaling-checklist.md) for concrete techniques covering caching, async processing, partitioning, horizontal scaling, and multi-region.
|
||||
|
||||
## Iterative Architecture
|
||||
|
||||
Architecture is discovered, not designed upfront:
|
||||
|
||||
1. **Start obvious** - group by domain, not by technical layer
|
||||
2. **Let hotspots emerge** - monitor which modules change together
|
||||
3. **Extract when painful** - split only when the current form causes measurable problems
|
||||
4. **Document decisions** - record why boundaries exist so future you knows what's load-bearing
|
||||
|
||||
Every senior engineer has a graveyard of over-engineered systems they regret. Learn from their pain. Build boring systems that work.
|
||||
120
.cursor/skills/system-architecture/examples.md
Normal file
120
.cursor/skills/system-architecture/examples.md
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# Architecture Examples
|
||||
|
||||
Concrete good/bad examples for each core principle in SKILL.md.
|
||||
|
||||
---
|
||||
|
||||
## Boundaries at Real Differences
|
||||
|
||||
**Good** - Meaningful boundary:
|
||||
```
|
||||
# Users and Billing are separate bounded contexts
|
||||
# - Different teams own them
|
||||
# - Different change cadences (users: weekly, billing: quarterly)
|
||||
# - Different compliance requirements
|
||||
|
||||
src/
|
||||
users/ # User management domain
|
||||
models.py
|
||||
services.py
|
||||
api.py
|
||||
billing/ # Billing domain
|
||||
models.py
|
||||
services.py
|
||||
api.py
|
||||
shared/ # Truly shared utilities
|
||||
auth.py
|
||||
```
|
||||
|
||||
**Bad** - Ceremony without purpose:
|
||||
```
|
||||
# UserService → UserRepository → UserRepositoryImpl
|
||||
# ...when you'll never swap the database
|
||||
|
||||
src/
|
||||
interfaces/
|
||||
IUserRepository.py # One implementation exists
|
||||
repositories/
|
||||
UserRepositoryImpl.py # Wraps SQLAlchemy, which is already a repository
|
||||
services/
|
||||
UserService.py # Just calls the repository
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Dependencies Flow Inward
|
||||
|
||||
**Good** - Clear dependency direction:
|
||||
```
|
||||
# Dependency flows inward: infrastructure → application → domain
|
||||
|
||||
domain/ # Pure business logic, no imports from outer layers
|
||||
order.py # Order entity with business rules
|
||||
|
||||
application/ # Use cases, orchestrates domain
|
||||
place_order.py # Imports from domain/, not infrastructure/
|
||||
|
||||
infrastructure/ # External concerns
|
||||
postgres.py # Implements persistence, imports from application/
|
||||
stripe.py # Implements payments
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Follow the Data
|
||||
|
||||
**Good** - Obvious data flow:
|
||||
```
|
||||
Request → Validate → Transform → Store → Respond
|
||||
|
||||
# Each step is a clear function/module:
|
||||
api/routes.py # Request enters
|
||||
validators.py # Validation
|
||||
transformers.py # Business logic transformation
|
||||
repositories.py # Storage
|
||||
serializers.py # Response shaping
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Design for Failure
|
||||
|
||||
**Good** - Failure-aware design with compensation:
|
||||
```python
|
||||
class OrderService:
|
||||
def place_order(self, order: Order) -> Result:
|
||||
inventory = self.inventory.reserve(order.items)
|
||||
if inventory.failed:
|
||||
return Result.failure("Items unavailable", retry=False)
|
||||
|
||||
payment = self.payments.charge(order.total)
|
||||
if payment.failed:
|
||||
self.inventory.release(inventory.reservation_id) # Compensate
|
||||
return Result.failure("Payment failed", retry=True)
|
||||
|
||||
return Result.success(order)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Design for Operations
|
||||
|
||||
**Good** - Observable architecture:
|
||||
```python
|
||||
@trace
|
||||
def handle_request(request):
|
||||
log.info("Processing", request_id=request.id, user=request.user_id)
|
||||
try:
|
||||
result = process(request)
|
||||
log.info("Completed", request_id=request.id, result=result.status)
|
||||
return result
|
||||
except Exception as e:
|
||||
log.error("Failed", request_id=request.id, error=str(e),
|
||||
context=request.to_dict()) # Full context for replay
|
||||
raise
|
||||
```
|
||||
|
||||
Key elements:
|
||||
- Every request gets a correlation ID
|
||||
- Every service logs with that ID
|
||||
- Every error includes full context for reproduction
|
||||
76
.cursor/skills/system-architecture/scaling-checklist.md
Normal file
76
.cursor/skills/system-architecture/scaling-checklist.md
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Scaling Checklist
|
||||
|
||||
Concrete techniques for when the complexity checklist in SKILL.md confirms scale is a real problem. Apply in order - each level solves the previous level's bottleneck.
|
||||
|
||||
---
|
||||
|
||||
## Level 0: Optimize First
|
||||
|
||||
Before adding infrastructure, exhaust these:
|
||||
|
||||
- [ ] Database queries have proper indexes
|
||||
- [ ] N+1 queries eliminated
|
||||
- [ ] Connection pooling configured
|
||||
- [ ] Slow endpoints profiled and optimized
|
||||
- [ ] Static assets served via CDN
|
||||
|
||||
## Level 1: Read-Heavy
|
||||
|
||||
**Symptom**: Database reads are the bottleneck.
|
||||
|
||||
| Technique | When | Trade-off |
|
||||
|-----------|------|-----------|
|
||||
| Application cache (in-memory) | Small, frequently accessed data | Stale data, memory pressure |
|
||||
| Redis/Memcached | Shared cache across instances | Network hop, cache invalidation complexity |
|
||||
| Read replicas | High read volume, slight staleness OK | Replication lag, eventual consistency |
|
||||
| CDN | Static or semi-static content | Cache invalidation delay |
|
||||
|
||||
## Level 2: Write-Heavy
|
||||
|
||||
**Symptom**: Database writes or processing are the bottleneck.
|
||||
|
||||
| Technique | When | Trade-off |
|
||||
|-----------|------|-----------|
|
||||
| Async task queue (Celery, SQS) | Work can be deferred | Eventual consistency, failure handling |
|
||||
| Write-behind cache | Batch frequent writes | Data loss risk on crash |
|
||||
| Event streaming (Kafka) | Multiple consumers of same data | Operational complexity, ordering guarantees |
|
||||
| CQRS | Reads and writes have diverged significantly | Two models to maintain |
|
||||
|
||||
## Level 3: Traffic Spikes
|
||||
|
||||
**Symptom**: Individual instances can't handle peak load.
|
||||
|
||||
| Technique | When | Trade-off |
|
||||
|-----------|------|-----------|
|
||||
| Horizontal scaling + load balancer | Stateless services | Session management, deploy complexity |
|
||||
| Auto-scaling | Unpredictable traffic patterns | Cold start latency, cost spikes |
|
||||
| Rate limiting | Protect against abuse/spikes | Legitimate users may be throttled |
|
||||
| Circuit breakers | Downstream services degrade | Partial functionality during failures |
|
||||
|
||||
## Level 4: Data Growth
|
||||
|
||||
**Symptom**: Single database can't hold or query all the data efficiently.
|
||||
|
||||
| Technique | When | Trade-off |
|
||||
|-----------|------|-----------|
|
||||
| Table partitioning | Time-series or naturally partitioned data | Query complexity, partition management |
|
||||
| Archival / cold storage | Old data rarely accessed | Access latency for archived data |
|
||||
| Database sharding | Partitioning insufficient, clear shard key exists | Cross-shard queries, operational burden |
|
||||
| Search index (Elasticsearch) | Full-text or complex queries on large datasets | Index lag, another system to operate |
|
||||
|
||||
## Level 5: Multi-Region
|
||||
|
||||
**Symptom**: Users are geographically distributed, latency matters.
|
||||
|
||||
| Technique | When | Trade-off |
|
||||
|-----------|------|-----------|
|
||||
| CDN + edge caching | Static/semi-static content | Cache invalidation |
|
||||
| Read replicas per region | Read-heavy, slight staleness OK | Replication lag |
|
||||
| Active-passive failover | Disaster recovery | Failover time, cost of standby |
|
||||
| Active-active multi-region | True global low-latency required | Conflict resolution, extreme complexity |
|
||||
|
||||
---
|
||||
|
||||
## Decision Rule
|
||||
|
||||
Always start at Level 0. Move to the next level only when you have **measured evidence** that the current level is insufficient. Skipping levels is how you end up with Kafka for a TODO app.
|
||||
2
.github/workflows/desktop-release.yml
vendored
2
.github/workflows/desktop-release.yml
vendored
|
|
@ -57,7 +57,7 @@ jobs:
|
|||
working-directory: surfsense_web
|
||||
env:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }}
|
||||
NEXT_PUBLIC_ELECTRIC_URL: ${{ vars.NEXT_PUBLIC_ELECTRIC_URL }}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}
|
||||
|
||||
|
|
|
|||
3
.github/workflows/docker-build.yml
vendored
3
.github/workflows/docker-build.yml
vendored
|
|
@ -164,8 +164,7 @@ jobs:
|
|||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_URL=__NEXT_PUBLIC_ELECTRIC_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ELECTRIC_AUTH_MODE=__NEXT_PUBLIC_ELECTRIC_AUTH_MODE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
|
||||
|
||||
- name: Export digest
|
||||
|
|
|
|||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -5,4 +5,4 @@ node_modules/
|
|||
.ruff_cache/
|
||||
.venv
|
||||
.pnpm-store
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
35
.vscode/launch.json
vendored
35
.vscode/launch.json
vendored
|
|
@ -22,7 +22,11 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "${command:python.interpreterPath}"
|
||||
"python": "uv",
|
||||
"pythonArgs": [
|
||||
"run",
|
||||
"python"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Backend: FastAPI (No Reload)",
|
||||
|
|
@ -32,7 +36,11 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "${command:python.interpreterPath}"
|
||||
"python": "uv",
|
||||
"pythonArgs": [
|
||||
"run",
|
||||
"python"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Backend: FastAPI (main.py)",
|
||||
|
|
@ -41,14 +49,19 @@
|
|||
"program": "${workspaceFolder}/surfsense_backend/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend"
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "uv",
|
||||
"pythonArgs": [
|
||||
"run",
|
||||
"python"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Next.js",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceFolder}/surfsense_web",
|
||||
"runtimeExecutable": "npm",
|
||||
"runtimeExecutable": "pnpm",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"console": "integratedTerminal",
|
||||
"serverReadyAction": {
|
||||
|
|
@ -62,7 +75,7 @@
|
|||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceFolder}/surfsense_web",
|
||||
"runtimeExecutable": "npm",
|
||||
"runtimeExecutable": "pnpm",
|
||||
"runtimeArgs": ["run", "debug:server"],
|
||||
"console": "integratedTerminal",
|
||||
"serverReadyAction": {
|
||||
|
|
@ -87,7 +100,11 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "${command:python.interpreterPath}"
|
||||
"python": "uv",
|
||||
"pythonArgs": [
|
||||
"run",
|
||||
"python"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Celery: Beat Scheduler",
|
||||
|
|
@ -103,7 +120,11 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "${command:python.interpreterPath}"
|
||||
"python": "uv",
|
||||
"pythonArgs": [
|
||||
"run",
|
||||
"python"
|
||||
]
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
|
|
|
|||
22
README.es.md
22
README.es.md
|
|
@ -27,11 +27,18 @@ SurfSense es un agente de investigación de IA altamente personalizable, conecta
|
|||
|
||||
|
||||
|
||||
# Video
|
||||
# Demo
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## Ejemplo de Podcast
|
||||
## Ejemplo de Agente de Video
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/cc977e6d-8292-4ffe-abb8-3b0560ef5562
|
||||
|
||||
|
||||
|
||||
## Ejemplo de Agente de Podcast
|
||||
|
||||
https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||
|
||||
|
|
@ -46,20 +53,25 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
2. Conecta tus conectores y sincroniza. Activa la sincronización periódica para mantenerlos actualizados.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/59da61d7-da05-4576-b7c0-dbc09f5985e8" alt="Conectores" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Conectores" /></p>
|
||||
|
||||
3. Mientras se indexan los datos de los conectores, sube documentos.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/d1e8b2e2-9eac-41d8-bdc0-f0cdc405d128" alt="Subir Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Subir Documentos" /></p>
|
||||
|
||||
4. Una vez que todo esté indexado, pregunta lo que quieras (Casos de uso):
|
||||
|
||||
- Generación de videos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Generación de Videos" /></p>
|
||||
|
||||
- Búsqueda básica y citaciones
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Búsqueda y Citación" /></p>
|
||||
|
||||
- QNA con mención de documentos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="QNA con Mención de Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="QNA con Mención de Documentos" /></p>
|
||||
|
||||
- Generación de informes y exportaciones (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto plano)
|
||||
|
|
@ -133,6 +145,8 @@ Para Docker Compose, instalación manual y otras opciones de despliegue, consult
|
|||
| Soporte Universal de LLM | 100+ LLMs, 6000+ modelos de embeddings, todos los principales rerankers vía OpenAI spec y LiteLLM |
|
||||
| Privacidad Primero | Soporte completo de LLM local (vLLM, Ollama) tus datos son tuyos |
|
||||
| Colaboración en Equipo | RBAC con roles de Propietario / Admin / Editor / Visor, chat en tiempo real e hilos de comentarios |
|
||||
| Generación de Videos | Genera videos con narración y visuales |
|
||||
| Generación de Presentaciones | Crea presentaciones editables basadas en diapositivas |
|
||||
| Generación de Podcasts | Podcast de 3 min en menos de 20 segundos; múltiples proveedores TTS (OpenAI, Azure, Kokoro) |
|
||||
| Extensión de Navegador | Extensión multi-navegador para guardar cualquier página web, incluyendo páginas protegidas por autenticación |
|
||||
| 25+ Conectores | Motores de búsqueda, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord y [más](#fuentes-externas) |
|
||||
|
|
|
|||
22
README.hi.md
22
README.hi.md
|
|
@ -27,11 +27,18 @@ SurfSense एक अत्यधिक अनुकूलन योग्य AI
|
|||
|
||||
|
||||
|
||||
# वीडियो
|
||||
# डेमो
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## पॉडकास्ट नमूना
|
||||
## वीडियो एजेंट नमूना
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/cc977e6d-8292-4ffe-abb8-3b0560ef5562
|
||||
|
||||
|
||||
|
||||
## पॉडकास्ट एजेंट नमूना
|
||||
|
||||
https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||
|
||||
|
|
@ -46,20 +53,25 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
2. अपने कनेक्टर जोड़ें और सिंक करें। कनेक्टर्स को अपडेट रखने के लिए आवधिक सिंकिंग सक्षम करें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/59da61d7-da05-4576-b7c0-dbc09f5985e8" alt="कनेक्टर्स" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="कनेक्टर्स" /></p>
|
||||
|
||||
3. जब तक कनेक्टर्स का डेटा इंडेक्स हो रहा है, दस्तावेज़ अपलोड करें।
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/d1e8b2e2-9eac-41d8-bdc0-f0cdc405d128" alt="दस्तावेज़ अपलोड करें" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="दस्तावेज़ अपलोड करें" /></p>
|
||||
|
||||
4. सब कुछ इंडेक्स हो जाने के बाद, कुछ भी पूछें (उपयोग के मामले):
|
||||
|
||||
- वीडियो जनरेशन
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="वीडियो जनरेशन" /></p>
|
||||
|
||||
- बेसिक सर्च और उद्धरण
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="सर्च और उद्धरण" /></p>
|
||||
|
||||
- दस्तावेज़ मेंशन QNA
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="दस्तावेज़ मेंशन QNA" /></p>
|
||||
|
||||
- रिपोर्ट जनरेशन और एक्सपोर्ट (PDF, DOCX, HTML, LaTeX, EPUB, ODT, सादा टेक्स्ट)
|
||||
|
|
@ -133,6 +145,8 @@ Docker Compose, मैनुअल इंस्टॉलेशन और अन
|
|||
| यूनिवर्सल LLM सपोर्ट | 100+ LLMs, 6000+ एम्बेडिंग मॉडल, सभी प्रमुख रीरैंकर्स OpenAI spec और LiteLLM के माध्यम से |
|
||||
| प्राइवेसी फर्स्ट | पूर्ण लोकल LLM सपोर्ट (vLLM, Ollama) आपका डेटा आपका रहता है |
|
||||
| टीम सहयोग | मालिक / एडमिन / संपादक / दर्शक भूमिकाओं के साथ RBAC, रीयल-टाइम चैट और कमेंट थ्रेड |
|
||||
| वीडियो जनरेशन | नैरेशन और विज़ुअल के साथ वीडियो बनाएं |
|
||||
| प्रेजेंटेशन जनरेशन | संपादन योग्य, स्लाइड आधारित प्रेजेंटेशन बनाएं |
|
||||
| पॉडकास्ट जनरेशन | 20 सेकंड से कम में 3 मिनट का पॉडकास्ट; कई TTS प्रदाता (OpenAI, Azure, Kokoro) |
|
||||
| ब्राउज़र एक्सटेंशन | किसी भी वेबपेज को सहेजने के लिए क्रॉस-ब्राउज़र एक्सटेंशन, प्रमाणीकरण सुरक्षित पेज सहित |
|
||||
| 25+ कनेक्टर्स | सर्च इंजन, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord और [अधिक](#बाहरी-स्रोत) |
|
||||
|
|
|
|||
23
README.md
23
README.md
|
|
@ -27,11 +27,19 @@ SurfSense is a highly customizable AI research agent, connected to external sour
|
|||
|
||||
|
||||
|
||||
# Video
|
||||
# Demo
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## Podcast Sample
|
||||
## Video Agent Sample
|
||||
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/012a7ffa-6f76-4f06-9dda-7632b470057a
|
||||
|
||||
|
||||
|
||||
## Podcast Agent Sample
|
||||
|
||||
https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||
|
||||
|
|
@ -46,20 +54,25 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
2. Connect your connectors and sync. Enable periodic syncing to keep connectors synced.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/59da61d7-da05-4576-b7c0-dbc09f5985e8" alt="Connectors" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Connectors" /></p>
|
||||
|
||||
3. Till connectors data index, upload Documents.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/d1e8b2e2-9eac-41d8-bdc0-f0cdc405d128" alt="Upload Documents" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Upload Documents" /></p>
|
||||
|
||||
4. Once everything is indexed, Ask Away (Use Cases):
|
||||
|
||||
- Video Generation
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Search and Citation" /></p>
|
||||
|
||||
- Basic search and citation
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Search and Citation" /></p>
|
||||
|
||||
- Document Mention QNA
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="Document Mention QNA" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="Document Mention QNA" /></p>
|
||||
|
||||
- Report Generations and Exports (PDF, DOCX, HTML, LaTeX, EPUB, ODT, Plain Text)
|
||||
|
|
@ -133,6 +146,8 @@ For Docker Compose, manual installation, and other deployment options, see the [
|
|||
| Universal LLM Support | 100+ LLMs, 6000+ embedding models, all major rerankers via OpenAI spec & LiteLLM |
|
||||
| Privacy First | Full local LLM support (vLLM, Ollama) your data stays yours |
|
||||
| Team Collaboration | RBAC with Owner / Admin / Editor / Viewer roles, real time chat & comment threads |
|
||||
| Video Generation | Generate videos with narration and visuals |
|
||||
| Presentation Generation | Create editable, slide based presentations |
|
||||
| Podcast Generation | 3 min podcast in under 20 seconds; multiple TTS providers (OpenAI, Azure, Kokoro) |
|
||||
| Browser Extension | Cross browser extension to save any webpage, including auth protected pages |
|
||||
| 25+ Connectors | Search Engines, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord & [more](#external-sources) |
|
||||
|
|
|
|||
|
|
@ -27,11 +27,18 @@ SurfSense é um agente de pesquisa de IA altamente personalizável, conectado a
|
|||
|
||||
|
||||
|
||||
# Vídeo
|
||||
# Demo
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## Exemplo de Podcast
|
||||
## Exemplo de Agente de Vídeo
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/cc977e6d-8292-4ffe-abb8-3b0560ef5562
|
||||
|
||||
|
||||
|
||||
## Exemplo de Agente de Podcast
|
||||
|
||||
https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||
|
||||
|
|
@ -46,20 +53,25 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
2. Conecte seus conectores e sincronize. Ative a sincronização periódica para manter os conectores atualizados.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/59da61d7-da05-4576-b7c0-dbc09f5985e8" alt="Conectores" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="Conectores" /></p>
|
||||
|
||||
3. Enquanto os dados dos conectores são indexados, faça upload de documentos.
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/d1e8b2e2-9eac-41d8-bdc0-f0cdc405d128" alt="Upload de Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="Upload de Documentos" /></p>
|
||||
|
||||
4. Quando tudo estiver indexado, pergunte o que quiser (Casos de uso):
|
||||
|
||||
- Geração de vídeos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="Geração de Vídeos" /></p>
|
||||
|
||||
- Busca básica e citações
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="Busca e Citação" /></p>
|
||||
|
||||
- QNA com menção de documentos
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="QNA com Menção de Documentos" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="QNA com Menção de Documentos" /></p>
|
||||
|
||||
- Geração de relatórios e exportações (PDF, DOCX, HTML, LaTeX, EPUB, ODT, texto simples)
|
||||
|
|
@ -133,6 +145,8 @@ Para Docker Compose, instalação manual e outras opções de implantação, con
|
|||
| Suporte Universal de LLM | 100+ LLMs, 6000+ modelos de embeddings, todos os principais rerankers via OpenAI spec e LiteLLM |
|
||||
| Privacidade em Primeiro Lugar | Suporte completo a LLM local (vLLM, Ollama) seus dados ficam com você |
|
||||
| Colaboração em Equipe | RBAC com papéis de Proprietário / Admin / Editor / Visualizador, chat em tempo real e threads de comentários |
|
||||
| Geração de Vídeos | Gera vídeos com narração e visuais |
|
||||
| Geração de Apresentações | Cria apresentações editáveis baseadas em slides |
|
||||
| Geração de Podcasts | Podcast de 3 min em menos de 20 segundos; múltiplos provedores TTS (OpenAI, Azure, Kokoro) |
|
||||
| Extensão de Navegador | Extensão multi-navegador para salvar qualquer página web, incluindo páginas protegidas por autenticação |
|
||||
| 25+ Conectores | Mecanismos de busca, Google Drive, Slack, Teams, Jira, Notion, GitHub, Discord e [mais](#fontes-externas) |
|
||||
|
|
|
|||
|
|
@ -27,11 +27,18 @@ SurfSense 是一个高度可定制的 AI 研究助手,可以连接外部数据
|
|||
|
||||
|
||||
|
||||
# 视频
|
||||
# 演示
|
||||
|
||||
https://github.com/user-attachments/assets/cc0c84d3-1f2f-4f7a-b519-2ecce22310b1
|
||||
|
||||
## 播客示例
|
||||
## 视频代理示例
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/cc977e6d-8292-4ffe-abb8-3b0560ef5562
|
||||
|
||||
|
||||
|
||||
## 播客代理示例
|
||||
|
||||
https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
||||
|
||||
|
|
@ -46,20 +53,25 @@ https://github.com/user-attachments/assets/a0a16566-6967-4374-ac51-9b3e07fbecd7
|
|||
|
||||
2. 连接您的连接器并同步。启用定期同步以保持连接器数据更新。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/59da61d7-da05-4576-b7c0-dbc09f5985e8" alt="连接器" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/0740f351-23fa-4909-9880-70aa1dcc1df7" alt="连接器" /></p>
|
||||
|
||||
3. 在连接器数据索引期间,上传文档。
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/d1e8b2e2-9eac-41d8-bdc0-f0cdc405d128" alt="上传文档" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/daf3dbae-ef86-4e86-82ea-fcbcad988761" alt="上传文档" /></p>
|
||||
|
||||
4. 一切索引完成后,尽管提问(使用场景):
|
||||
|
||||
- 视频生成
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/af85c0f3-6cfd-4757-9706-07fd5e32c857" alt="视频生成" /></p>
|
||||
|
||||
- 基本搜索和引用
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/81e797a1-e01a-4003-8e60-0a0b3a9789df" alt="搜索和引用" /></p>
|
||||
|
||||
- 文档提及问答
|
||||
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/65c3bf06-1d46-4dd5-b169-4d934c9b6798" alt="文档提及问答" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/be958295-0a8c-4707-998c-9fe1f1c007be" alt="文档提及问答" /></p>
|
||||
|
||||
- 报告生成和导出(PDF、DOCX、HTML、LaTeX、EPUB、ODT、纯文本)
|
||||
|
|
@ -133,6 +145,8 @@ irm https://raw.githubusercontent.com/MODSetter/SurfSense/main/docker/scripts/in
|
|||
| 通用 LLM 支持 | 100+ LLM、6000+ 嵌入模型、所有主流重排序器,通过 OpenAI spec 和 LiteLLM |
|
||||
| 隐私优先 | 完整本地 LLM 支持(vLLM、Ollama),您的数据由您掌控 |
|
||||
| 团队协作 | RBAC 角色控制(所有者/管理员/编辑者/查看者),实时聊天和评论线程 |
|
||||
| 视频生成 | 生成带有旁白和视觉效果的视频 |
|
||||
| 演示文稿生成 | 创建可编辑的幻灯片式演示文稿 |
|
||||
| 播客生成 | 20 秒内生成 3 分钟播客;多种 TTS 提供商(OpenAI、Azure、Kokoro) |
|
||||
| 浏览器扩展 | 跨浏览器扩展,保存任何网页,包括需要身份验证的页面 |
|
||||
| 25+ 连接器 | 搜索引擎、Google Drive、Slack、Teams、Jira、Notion、GitHub、Discord 等[更多](#外部数据源) |
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
|
||||
# BACKEND_PORT=8929
|
||||
# FRONTEND_PORT=3929
|
||||
# ELECTRIC_PORT=5929
|
||||
# ZERO_CACHE_PORT=5929
|
||||
# SEARXNG_PORT=8888
|
||||
# FLOWER_PORT=5555
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -57,7 +58,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
# NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
|
||||
# NEXT_PUBLIC_ELECTRIC_AUTH_MODE=insecure
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Custom Domain / Reverse Proxy
|
||||
|
|
@ -70,8 +70,35 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
# NEXT_FRONTEND_URL=https://app.yourdomain.com
|
||||
# BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_ELECTRIC_URL=https://electric.yourdomain.com
|
||||
# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Zero-cache (real-time sync)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Defaults work out of the box for Docker deployments.
|
||||
# Change ZERO_ADMIN_PASSWORD for security in production.
|
||||
|
||||
# ZERO_ADMIN_PASSWORD=surfsense-zero-admin
|
||||
# Full override for the Zero → Postgres connection URLs.
|
||||
# Leave commented out to use the Docker-managed `db` container (default).
|
||||
# ZERO_UPSTREAM_DB=postgresql://surfsense:surfsense@db:5432/surfsense
|
||||
# ZERO_CVR_DB=postgresql://surfsense:surfsense@db:5432/surfsense
|
||||
# ZERO_CHANGE_DB=postgresql://surfsense:surfsense@db:5432/surfsense
|
||||
|
||||
# ZERO_QUERY_URL: where zero-cache forwards query requests for resolution.
|
||||
# ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though
|
||||
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
|
||||
# skip its own JWT verification and let the app endpoints handle auth instead.
|
||||
# The mutate endpoint is a no-op that returns an empty response.
|
||||
# Default: Docker service networking (http://frontend:3000/api/zero/...).
|
||||
# Override when running the frontend outside Docker:
|
||||
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
|
||||
# Override for custom domain:
|
||||
# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query
|
||||
# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Database (defaults work out of the box, change for security)
|
||||
|
|
@ -100,19 +127,6 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
# Supports TLS: rediss://:password@host:6380/0
|
||||
# REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Electric SQL (real-time sync credentials)
|
||||
# ------------------------------------------------------------------------------
|
||||
# These must match on the db, backend, and electric services.
|
||||
# Change for security; defaults work out of the box.
|
||||
|
||||
# ELECTRIC_DB_USER=electric
|
||||
# ELECTRIC_DB_PASSWORD=electric_password
|
||||
# Full override for the Electric → Postgres connection URL.
|
||||
# Leave commented out to use the Docker-managed `db` container (default).
|
||||
# Uncomment and set `db` to `host.docker.internal` when pointing Electric at a local Postgres instance (e.g. Postgres.app on macOS):
|
||||
# ELECTRIC_DATABASE_URL=postgresql://electric:electric_password@db:5432/surfsense?sslmode=disable
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
|
@ -199,6 +213,16 @@ STT_SERVICE=local/base
|
|||
# COMPOSIO_ENABLED=TRUE
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG (bundled web search — works out of the box, no config needed)
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG provides web search to all search spaces automatically.
|
||||
# To access the SearXNG UI directly: http://localhost:8888
|
||||
# To disable the service entirely: docker compose up --scale searxng=0
|
||||
# To point at your own SearXNG instance instead of the bundled one:
|
||||
# SEARXNG_DEFAULT_HOST=http://your-searxng:8080
|
||||
# SEARXNG_SECRET=surfsense-searxng-secret
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Daytona Sandbox (optional — cloud code execution for the deep agent)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -18,13 +18,10 @@ services:
|
|||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
- POSTGRES_USER=${DB_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
|
||||
- POSTGRES_DB=${DB_NAME:-surfsense}
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres} -d ${DB_NAME:-surfsense}"]
|
||||
|
|
@ -57,6 +54,20 @@ services:
|
|||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
searxng:
|
||||
image: searxng/searxng:2026.3.13-3c1f68c59
|
||||
ports:
|
||||
- "${SEARXNG_PORT:-8888}:8080"
|
||||
volumes:
|
||||
- ./searxng:/etc/searxng
|
||||
environment:
|
||||
- SEARXNG_SECRET=${SEARXNG_SECRET:-surfsense-searxng-secret}
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--spider", "-q", "http://localhost:8080/healthz"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
backend:
|
||||
build: ../surfsense_backend
|
||||
ports:
|
||||
|
|
@ -77,10 +88,9 @@ services:
|
|||
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
|
||||
- LANGCHAIN_TRACING_V2=false
|
||||
- LANGSMITH_TRACING=false
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
|
||||
- SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
# - DAYTONA_SANDBOX_ENABLED=TRUE
|
||||
# - DAYTONA_API_KEY=${DAYTONA_API_KEY:-}
|
||||
|
|
@ -92,6 +102,8 @@ services:
|
|||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
searxng:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 15s
|
||||
|
|
@ -113,8 +125,7 @@ services:
|
|||
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- ELECTRIC_DB_USER=${ELECTRIC_DB_USER:-electric}
|
||||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
- SERVICE_ROLE=worker
|
||||
depends_on:
|
||||
db:
|
||||
|
|
@ -158,20 +169,28 @@ services:
|
|||
# - redis
|
||||
# - celery_worker
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:1.4.10
|
||||
zero-cache:
|
||||
image: rocicorp/zero:0.26.2
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5133}:3000"
|
||||
- "${ZERO_CACHE_PORT:-4848}:4848"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
- DATABASE_URL=${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
- ELECTRIC_INSECURE=true
|
||||
- ELECTRIC_WRITE_TO_PG_MODE=direct
|
||||
- ZERO_UPSTREAM_DB=${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
- ZERO_CVR_DB=${ZERO_CVR_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
- ZERO_CHANGE_DB=${ZERO_CHANGE_DB:-postgresql://${DB_USER:-postgres}:${DB_PASSWORD:-postgres}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
- ZERO_REPLICA_FILE=/data/zero.db
|
||||
- ZERO_ADMIN_PASSWORD=${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
|
||||
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
|
||||
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
|
||||
volumes:
|
||||
- zero_cache_data:/data
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
|
@ -183,8 +202,7 @@ services:
|
|||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:5133}
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
|
|
@ -193,7 +211,7 @@ services:
|
|||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
electric:
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
|
|
@ -205,3 +223,5 @@ volumes:
|
|||
name: surfsense-dev-redis
|
||||
shared_temp:
|
||||
name: surfsense-dev-shared-temp
|
||||
zero_cache_data:
|
||||
name: surfsense-dev-zero-cache
|
||||
|
|
|
|||
|
|
@ -15,13 +15,10 @@ services:
|
|||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./postgresql.conf:/etc/postgresql/postgresql.conf:ro
|
||||
- ./scripts/init-electric-user.sh:/docker-entrypoint-initdb.d/init-electric-user.sh:ro
|
||||
environment:
|
||||
POSTGRES_USER: ${DB_USER:-surfsense}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-surfsense}
|
||||
POSTGRES_DB: ${DB_NAME:-surfsense}
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
command: postgres -c config_file=/etc/postgresql/postgresql.conf
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
|
|
@ -42,6 +39,19 @@ services:
|
|||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
searxng:
|
||||
image: searxng/searxng:2026.3.13-3c1f68c59
|
||||
volumes:
|
||||
- ./searxng:/etc/searxng
|
||||
environment:
|
||||
SEARXNG_SECRET: ${SEARXNG_SECRET:-surfsense-searxng-secret}
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--spider", "-q", "http://localhost:8080/healthz"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
backend:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}
|
||||
ports:
|
||||
|
|
@ -59,9 +69,8 @@ services:
|
|||
PYTHONPATH: /app
|
||||
UVICORN_LOOP: asyncio
|
||||
UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}}
|
||||
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
# DAYTONA_SANDBOX_ENABLED: "TRUE"
|
||||
# DAYTONA_API_KEY: ${DAYTONA_API_KEY:-}
|
||||
|
|
@ -75,6 +84,8 @@ services:
|
|||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
searxng:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
|
|
@ -96,8 +107,7 @@ services:
|
|||
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_TASK_DEFAULT_QUEUE: surfsense
|
||||
PYTHONPATH: /app
|
||||
ELECTRIC_DB_USER: ${ELECTRIC_DB_USER:-electric}
|
||||
ELECTRIC_DB_PASSWORD: ${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
SERVICE_ROLE: worker
|
||||
depends_on:
|
||||
db:
|
||||
|
|
@ -148,20 +158,28 @@ services:
|
|||
# - celery_worker
|
||||
# restart: unless-stopped
|
||||
|
||||
electric:
|
||||
image: electricsql/electric:1.4.10
|
||||
zero-cache:
|
||||
image: rocicorp/zero:0.26.2
|
||||
ports:
|
||||
- "${ELECTRIC_PORT:-5929}:3000"
|
||||
- "${ZERO_CACHE_PORT:-5929}:4848"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
environment:
|
||||
DATABASE_URL: ${ELECTRIC_DATABASE_URL:-postgresql://${ELECTRIC_DB_USER:-electric}:${ELECTRIC_DB_PASSWORD:-electric_password}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
ELECTRIC_INSECURE: "true"
|
||||
ELECTRIC_WRITE_TO_PG_MODE: direct
|
||||
ZERO_UPSTREAM_DB: ${ZERO_UPSTREAM_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
ZERO_CVR_DB: ${ZERO_CVR_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
ZERO_CHANGE_DB: ${ZERO_CHANGE_DB:-postgresql://${DB_USER:-surfsense}:${DB_PASSWORD:-surfsense}@${DB_HOST:-db}:${DB_PORT:-5432}/${DB_NAME:-surfsense}?sslmode=${DB_SSLMODE:-disable}}
|
||||
ZERO_REPLICA_FILE: /data/zero.db
|
||||
ZERO_ADMIN_PASSWORD: ${ZERO_ADMIN_PASSWORD:-surfsense-zero-admin}
|
||||
ZERO_QUERY_URL: ${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
|
||||
ZERO_MUTATE_URL: ${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
|
||||
volumes:
|
||||
- zero_cache_data:/data
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3000/v1/health"]
|
||||
test: ["CMD", "curl", "-f", "http://localhost:4848/keepalive"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
|
@ -172,17 +190,16 @@ services:
|
|||
- "${FRONTEND_PORT:-3929}:3000"
|
||||
environment:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}}
|
||||
NEXT_PUBLIC_ELECTRIC_URL: ${NEXT_PUBLIC_ELECTRIC_URL:-http://localhost:${ELECTRIC_PORT:-5929}}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
NEXT_PUBLIC_ELECTRIC_AUTH_MODE: ${NEXT_PUBLIC_ELECTRIC_AUTH_MODE:-insecure}
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
electric:
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
|
|
@ -193,3 +210,5 @@ volumes:
|
|||
name: surfsense-redis
|
||||
shared_temp:
|
||||
name: surfsense-shared-temp
|
||||
zero_cache_data:
|
||||
name: surfsense-zero-cache
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# PostgreSQL configuration for Electric SQL
|
||||
# PostgreSQL configuration for SurfSense
|
||||
# This file is mounted into the PostgreSQL container
|
||||
|
||||
listen_addresses = '*'
|
||||
max_connections = 200
|
||||
shared_buffers = 256MB
|
||||
|
||||
# Enable logical replication (required for Electric SQL)
|
||||
# Enable logical replication (required for Zero-cache real-time sync)
|
||||
wal_level = logical
|
||||
max_replication_slots = 10
|
||||
max_wal_senders = 10
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
#!/bin/sh
|
||||
# Creates the Electric SQL replication user on first DB initialization.
|
||||
# Idempotent — safe to run alongside Alembic migration 66.
|
||||
|
||||
set -e
|
||||
|
||||
ELECTRIC_DB_USER="${ELECTRIC_DB_USER:-electric}"
|
||||
ELECTRIC_DB_PASSWORD="${ELECTRIC_DB_PASSWORD:-electric_password}"
|
||||
|
||||
echo "Creating Electric SQL replication user: $ELECTRIC_DB_USER"
|
||||
|
||||
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
|
||||
DO \$\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_user WHERE usename = '$ELECTRIC_DB_USER') THEN
|
||||
CREATE USER $ELECTRIC_DB_USER WITH REPLICATION PASSWORD '$ELECTRIC_DB_PASSWORD';
|
||||
END IF;
|
||||
END
|
||||
\$\$;
|
||||
|
||||
GRANT CONNECT ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT CREATE ON DATABASE $POSTGRES_DB TO $ELECTRIC_DB_USER;
|
||||
GRANT USAGE ON SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO $ELECTRIC_DB_USER;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO $ELECTRIC_DB_USER;
|
||||
|
||||
DO \$\$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
CREATE PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
END
|
||||
\$\$;
|
||||
EOSQL
|
||||
|
||||
echo "Electric SQL user '$ELECTRIC_DB_USER' and publication created successfully"
|
||||
|
|
@ -103,13 +103,15 @@ Write-Step "Downloading SurfSense files"
|
|||
Write-Info "Installation directory: $InstallDir"
|
||||
|
||||
New-Item -ItemType Directory -Path "$InstallDir\scripts" -Force | Out-Null
|
||||
New-Item -ItemType Directory -Path "$InstallDir\searxng" -Force | Out-Null
|
||||
|
||||
$Files = @(
|
||||
@{ Src = "docker/docker-compose.yml"; Dest = "docker-compose.yml" }
|
||||
@{ Src = "docker/.env.example"; Dest = ".env.example" }
|
||||
@{ Src = "docker/postgresql.conf"; Dest = "postgresql.conf" }
|
||||
@{ Src = "docker/scripts/init-electric-user.sh"; Dest = "scripts/init-electric-user.sh" }
|
||||
@{ Src = "docker/scripts/migrate-database.ps1"; Dest = "scripts/migrate-database.ps1" }
|
||||
@{ Src = "docker/searxng/settings.yml"; Dest = "searxng/settings.yml" }
|
||||
@{ Src = "docker/searxng/limiter.toml"; Dest = "searxng/limiter.toml" }
|
||||
)
|
||||
|
||||
foreach ($f in $Files) {
|
||||
|
|
|
|||
|
|
@ -102,13 +102,15 @@ wait_for_pg() {
|
|||
step "Downloading SurfSense files"
|
||||
info "Installation directory: ${INSTALL_DIR}"
|
||||
mkdir -p "${INSTALL_DIR}/scripts"
|
||||
mkdir -p "${INSTALL_DIR}/searxng"
|
||||
|
||||
FILES=(
|
||||
"docker/docker-compose.yml:docker-compose.yml"
|
||||
"docker/.env.example:.env.example"
|
||||
"docker/postgresql.conf:postgresql.conf"
|
||||
"docker/scripts/init-electric-user.sh:scripts/init-electric-user.sh"
|
||||
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
|
||||
"docker/searxng/settings.yml:searxng/settings.yml"
|
||||
"docker/searxng/limiter.toml:searxng/limiter.toml"
|
||||
)
|
||||
|
||||
for entry in "${FILES[@]}"; do
|
||||
|
|
@ -119,7 +121,6 @@ for entry in "${FILES[@]}"; do
|
|||
|| error "Failed to download ${dest}. Check your internet connection and try again."
|
||||
done
|
||||
|
||||
chmod +x "${INSTALL_DIR}/scripts/init-electric-user.sh"
|
||||
chmod +x "${INSTALL_DIR}/scripts/migrate-database.sh"
|
||||
success "All files downloaded to ${INSTALL_DIR}/"
|
||||
|
||||
|
|
|
|||
5
docker/searxng/limiter.toml
Normal file
5
docker/searxng/limiter.toml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
[botdetection.ip_limit]
|
||||
link_token = false
|
||||
|
||||
[botdetection.ip_lists]
|
||||
pass_ip = ["0.0.0.0/0"]
|
||||
90
docker/searxng/settings.yml
Normal file
90
docker/searxng/settings.yml
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
use_default_settings:
|
||||
engines:
|
||||
remove:
|
||||
- ahmia
|
||||
- torch
|
||||
- qwant
|
||||
- qwant news
|
||||
- qwant images
|
||||
- qwant videos
|
||||
- mojeek
|
||||
- mojeek images
|
||||
- mojeek news
|
||||
|
||||
server:
|
||||
secret_key: "override-me-via-env"
|
||||
limiter: false
|
||||
image_proxy: false
|
||||
method: "GET"
|
||||
default_http_headers:
|
||||
X-Robots-Tag: "noindex, nofollow"
|
||||
|
||||
search:
|
||||
formats:
|
||||
- html
|
||||
- json
|
||||
default_lang: "auto"
|
||||
autocomplete: ""
|
||||
safe_search: 0
|
||||
ban_time_on_fail: 5
|
||||
max_ban_time_on_fail: 120
|
||||
suspended_times:
|
||||
SearxEngineAccessDenied: 3600
|
||||
SearxEngineCaptcha: 3600
|
||||
SearxEngineTooManyRequests: 600
|
||||
cf_SearxEngineCaptcha: 7200
|
||||
cf_SearxEngineAccessDenied: 3600
|
||||
recaptcha_SearxEngineCaptcha: 7200
|
||||
|
||||
ui:
|
||||
static_use_hash: true
|
||||
|
||||
outgoing:
|
||||
request_timeout: 12.0
|
||||
max_request_timeout: 20.0
|
||||
pool_connections: 100
|
||||
pool_maxsize: 20
|
||||
enable_http2: true
|
||||
extra_proxy_timeout: 10
|
||||
retries: 1
|
||||
# Uncomment and set your residential proxy URL to route search engine requests through it.
|
||||
# Format: http://<username>:<base64_password>@<hostname>:<port>/
|
||||
#
|
||||
# proxies:
|
||||
# all://:
|
||||
# - http://user:pass@proxy-host:port/
|
||||
|
||||
engines:
|
||||
- name: google
|
||||
disabled: false
|
||||
weight: 1.2
|
||||
retry_on_http_error: [429, 503]
|
||||
- name: duckduckgo
|
||||
disabled: false
|
||||
weight: 1.1
|
||||
retry_on_http_error: [429, 503]
|
||||
- name: brave
|
||||
disabled: false
|
||||
weight: 1.0
|
||||
retry_on_http_error: [429, 503]
|
||||
- name: bing
|
||||
disabled: false
|
||||
weight: 0.9
|
||||
retry_on_http_error: [429, 503]
|
||||
- name: wikipedia
|
||||
disabled: false
|
||||
weight: 0.8
|
||||
- name: stackoverflow
|
||||
disabled: false
|
||||
weight: 0.7
|
||||
- name: yahoo
|
||||
disabled: false
|
||||
weight: 0.7
|
||||
retry_on_http_error: [429, 503]
|
||||
- name: wikidata
|
||||
disabled: false
|
||||
weight: 0.6
|
||||
- name: currency
|
||||
disabled: false
|
||||
- name: ddg definitions
|
||||
disabled: false
|
||||
|
|
@ -12,9 +12,10 @@ REDIS_APP_URL=redis://localhost:6379/0
|
|||
# Optional: TTL in seconds for connector indexing lock key
|
||||
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
|
||||
|
||||
#Electric(for migrations only)
|
||||
ELECTRIC_DB_USER=electric
|
||||
ELECTRIC_DB_PASSWORD=electric_password
|
||||
# Platform Web Search (SearXNG)
|
||||
# Set this to enable built-in web search. Docker Compose sets it automatically.
|
||||
# Only uncomment if running the backend outside Docker (e.g. uvicorn on host).
|
||||
# SEARXNG_DEFAULT_HOST=http://localhost:8888
|
||||
|
||||
# Periodic task interval
|
||||
# # Run every minute (default)
|
||||
|
|
@ -99,7 +100,8 @@ TEAMS_CLIENT_ID=your_teams_client_id_here
|
|||
TEAMS_CLIENT_SECRET=your_teams_client_secret_here
|
||||
TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
|
||||
#Composio Coonnector
|
||||
# Composio Connector
|
||||
# NOTE: Disable "Mask Connected Account Secrets" in Composio dashboard (Settings → Project Settings) for Google indexing to work.
|
||||
COMPOSIO_API_KEY=your_api_key_here
|
||||
COMPOSIO_ENABLED=TRUE
|
||||
COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
|
|
|
|||
1
surfsense_backend/.gitignore
vendored
1
surfsense_backend/.gitignore
vendored
|
|
@ -6,6 +6,7 @@ __pycache__/
|
|||
.flashrank_cache
|
||||
surf_new_backend.egg-info/
|
||||
podcasts/
|
||||
video_presentation_audio/
|
||||
sandbox_files/
|
||||
temp_audio/
|
||||
celerybeat-schedule*
|
||||
|
|
|
|||
|
|
@ -25,13 +25,6 @@ database_url = os.getenv("DATABASE_URL")
|
|||
if database_url:
|
||||
config.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Electric SQL user credentials - centralized configuration for migrations
|
||||
# These are used by migrations that set up Electric SQL replication
|
||||
config.set_main_option("electric_db_user", os.getenv("ELECTRIC_DB_USER", "electric"))
|
||||
config.set_main_option(
|
||||
"electric_db_password", os.getenv("ELECTRIC_DB_PASSWORD", "electric_password")
|
||||
)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
|
|
|||
|
|
@ -30,21 +30,25 @@ def upgrade() -> None:
|
|||
"ix_notifications_user_read_type_created",
|
||||
"notifications",
|
||||
["user_id", "read", "type", "created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_notifications_user_space_created",
|
||||
"notifications",
|
||||
["user_id", "search_space_id", "created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_notifications_type",
|
||||
"notifications",
|
||||
["type"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_notifications_search_space_id",
|
||||
"notifications",
|
||||
["search_space_id"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
"""Add video_presentations table and video_presentation_status enum
|
||||
|
||||
Revision ID: 107
|
||||
Revises: 106
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import ENUM, JSONB
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "107"
|
||||
down_revision: str | None = "106"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
video_presentation_status_enum = ENUM(
|
||||
"pending",
|
||||
"generating",
|
||||
"ready",
|
||||
"failed",
|
||||
name="video_presentation_status",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE video_presentation_status AS ENUM ('pending', 'generating', 'ready', 'failed');
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = 'video_presentations'")
|
||||
)
|
||||
if not result.fetchone():
|
||||
op.create_table(
|
||||
"video_presentations",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("title", sa.String(length=500), nullable=False),
|
||||
sa.Column("slides", JSONB(), nullable=True),
|
||||
sa.Column("scene_codes", JSONB(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
video_presentation_status_enum,
|
||||
server_default="ready",
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("thread_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"],
|
||||
["searchspaces.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["thread_id"],
|
||||
["new_chat_threads.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_video_presentations_status",
|
||||
"video_presentations",
|
||||
["status"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_video_presentations_thread_id",
|
||||
"video_presentations",
|
||||
["thread_id"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_video_presentations_created_at",
|
||||
"video_presentations",
|
||||
["created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_video_presentations_created_at", table_name="video_presentations")
|
||||
op.drop_index("ix_video_presentations_thread_id", table_name="video_presentations")
|
||||
op.drop_index("ix_video_presentations_status", table_name="video_presentations")
|
||||
op.drop_table("video_presentations")
|
||||
op.execute("DROP TYPE IF EXISTS video_presentation_status")
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
"""Clean up Electric SQL artifacts (user, publication, replication slots)
|
||||
|
||||
Revision ID: 108
|
||||
Revises: 107
|
||||
|
||||
Removes leftover Electric SQL infrastructure that is no longer needed after
|
||||
the migration to Rocicorp Zero. Fully idempotent — safe on databases that
|
||||
never had Electric SQL set up (fresh installs).
|
||||
|
||||
Cleaned up:
|
||||
- Replication slots containing 'electric' (prevents unbounded WAL growth)
|
||||
- The 'electric_publication_default' publication
|
||||
- Default privileges, grants, and the 'electric' database user
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "108"
|
||||
down_revision: str | None = "107"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
slot RECORD;
|
||||
BEGIN
|
||||
-- 1. Drop inactive Electric replication slots (prevents WAL growth)
|
||||
FOR slot IN
|
||||
SELECT slot_name FROM pg_replication_slots
|
||||
WHERE slot_name LIKE '%electric%' AND active = false
|
||||
LOOP
|
||||
BEGIN
|
||||
PERFORM pg_drop_replication_slot(slot.slot_name);
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not drop replication slot %: %', slot.slot_name, SQLERRM;
|
||||
END;
|
||||
END LOOP;
|
||||
|
||||
-- Warn about active Electric slots that cannot be safely dropped
|
||||
FOR slot IN
|
||||
SELECT slot_name FROM pg_replication_slots
|
||||
WHERE slot_name LIKE '%electric%' AND active = true
|
||||
LOOP
|
||||
RAISE WARNING 'Active Electric replication slot "%" was not dropped — drop it manually to stop WAL growth', slot.slot_name;
|
||||
END LOOP;
|
||||
|
||||
-- 2. Drop the Electric publication
|
||||
BEGIN
|
||||
IF EXISTS (SELECT 1 FROM pg_publication WHERE pubname = 'electric_publication_default') THEN
|
||||
DROP PUBLICATION electric_publication_default;
|
||||
END IF;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not drop publication electric_publication_default: %', SQLERRM;
|
||||
END;
|
||||
|
||||
-- 3. Revoke privileges and drop the Electric user
|
||||
IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'electric') THEN
|
||||
BEGIN
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public
|
||||
REVOKE SELECT ON TABLES FROM electric;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA public
|
||||
REVOKE SELECT ON SEQUENCES FROM electric;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not revoke default privileges from electric: %', SQLERRM;
|
||||
END;
|
||||
|
||||
BEGIN
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM electric;
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM electric;
|
||||
REVOKE USAGE ON SCHEMA public FROM electric;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not revoke schema privileges from electric: %', SQLERRM;
|
||||
END;
|
||||
|
||||
BEGIN
|
||||
EXECUTE format(
|
||||
'REVOKE CONNECT ON DATABASE %I FROM electric',
|
||||
current_database()
|
||||
);
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not revoke CONNECT from electric: %', SQLERRM;
|
||||
END;
|
||||
|
||||
BEGIN
|
||||
REASSIGN OWNED BY electric TO CURRENT_USER;
|
||||
DROP ROLE electric;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'Could not drop role electric: %', SQLERRM;
|
||||
END;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
|
|
@ -21,6 +21,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
|
|
@ -37,13 +40,15 @@ _perf_log = get_perf_logger()
|
|||
# =============================================================================
|
||||
|
||||
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
|
||||
# used by the knowledge_base tool. Some connectors map to different document types.
|
||||
# used by the knowledge_base and web_search tools.
|
||||
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
|
||||
# the web_search tool; all others go to search_knowledge_base.
|
||||
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
||||
# Direct mappings (connector type == searchable type)
|
||||
# Live search connectors (handled by web_search tool)
|
||||
"TAVILY_API": "TAVILY_API",
|
||||
"SEARXNG_API": "SEARXNG_API",
|
||||
"LINKUP_API": "LINKUP_API",
|
||||
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
|
||||
# Local/indexed connectors (handled by search_knowledge_base tool)
|
||||
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
|
||||
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
|
||||
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
|
||||
|
|
@ -63,10 +68,11 @@ _CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
|||
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
|
||||
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
|
||||
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
# Composio connectors (unified to native document types).
|
||||
# Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
|
||||
# Document types that don't come from SearchSourceConnector but should always be searchable
|
||||
|
|
@ -233,6 +239,7 @@ async def create_surfsense_deep_agent(
|
|||
available_document_types = await connector_service.get_available_document_types(
|
||||
search_space_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
_perf_log.info(
|
||||
|
|
@ -289,6 +296,69 @@ async def create_surfsense_deep_agent(
|
|||
]
|
||||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
||||
has_google_drive_connector = (
|
||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_google_drive_connector:
|
||||
google_drive_tools = [
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
]
|
||||
modified_disabled_tools.extend(google_drive_tools)
|
||||
|
||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
||||
has_google_calendar_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_google_calendar_connector:
|
||||
calendar_tools = [
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
]
|
||||
modified_disabled_tools.extend(calendar_tools)
|
||||
|
||||
# Disable Gmail action tools if no Gmail connector is configured
|
||||
has_gmail_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_gmail_connector:
|
||||
gmail_tools = [
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
"trash_gmail_email",
|
||||
]
|
||||
modified_disabled_tools.extend(gmail_tools)
|
||||
|
||||
# Disable Jira action tools if no Jira connector is configured
|
||||
has_jira_connector = (
|
||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_jira_connector:
|
||||
jira_tools = [
|
||||
"create_jira_issue",
|
||||
"update_jira_issue",
|
||||
"delete_jira_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(jira_tools)
|
||||
|
||||
# Disable Confluence action tools if no Confluence connector is configured
|
||||
has_confluence_connector = (
|
||||
available_connectors is not None
|
||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_confluence_connector:
|
||||
confluence_tools = [
|
||||
"create_confluence_page",
|
||||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
]
|
||||
modified_disabled_tools.extend(confluence_tools)
|
||||
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
_t0 = time.perf_counter()
|
||||
tools = await build_tools_async(
|
||||
|
|
@ -342,6 +412,7 @@ async def create_surfsense_deep_agent(
|
|||
system_prompt=system_prompt,
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer,
|
||||
middleware=[DedupHITLToolCallsMiddleware()],
|
||||
**deep_agent_kwargs,
|
||||
)
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"delete_google_drive_file": "file_name",
|
||||
"delete_notion_page": "page_title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
"update_linear_issue": "issue_ref",
|
||||
"update_jira_issue": "issue_title_or_key",
|
||||
"delete_jira_issue": "issue_title_or_key",
|
||||
"update_confluence_page": "page_title_or_id",
|
||||
"delete_confluence_page": "page_title_or_id",
|
||||
}
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg]
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||
return None
|
||||
|
||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped: list[dict[str, Any]] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
|
||||
if dedup_key_arg is not None:
|
||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||
name,
|
||||
arg_val,
|
||||
)
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(tc)
|
||||
|
||||
if len(deduped) == len(tool_calls):
|
||||
return None
|
||||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
|
|
@ -99,14 +99,8 @@ _TOOL_INSTRUCTIONS["search_knowledge_base"] = """
|
|||
- IMPORTANT: When searching for information (meetings, schedules, notes, tasks, etc.), ALWAYS search broadly
|
||||
across ALL sources first by omitting connectors_to_search. The user may store information in various places
|
||||
including calendar apps, note-taking apps (Obsidian, Notion), chat apps (Slack, Discord), and more.
|
||||
- IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data
|
||||
(e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call
|
||||
`search_knowledge_base` using live web connectors via `connectors_to_search`:
|
||||
["LINKUP_API", "TAVILY_API", "SEARXNG_API", "BAIDU_SEARCH_API"].
|
||||
- For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet
|
||||
access before attempting a live connector search.
|
||||
- If the live connectors return no relevant results, explain that live web sources did not return enough
|
||||
data and ask the user if they want you to retry with a refined query.
|
||||
- This tool searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
|
||||
For real-time web search (current events, news, live data), use the `web_search` tool instead.
|
||||
- FALLBACK BEHAVIOR: If the search returns no relevant results, you MAY then answer using your own
|
||||
general knowledge, but clearly indicate that no matching information was found in the knowledge base.
|
||||
- Only narrow to specific connectors if the user explicitly asks (e.g., "check my Slack" or "in my calendar").
|
||||
|
|
@ -138,6 +132,17 @@ _TOOL_INSTRUCTIONS["generate_podcast"] = """
|
|||
- After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes).
|
||||
"""
|
||||
|
||||
_TOOL_INSTRUCTIONS["generate_video_presentation"] = """
|
||||
- generate_video_presentation: Generate a video presentation from provided content.
|
||||
- Use this when the user asks to create a video, presentation, slides, or slide deck.
|
||||
- Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation"
|
||||
- Args:
|
||||
- source_content: The text content to turn into a presentation. The more detailed, the better.
|
||||
- video_title: Optional title (default: "SurfSense Presentation")
|
||||
- user_prompt: Optional style instructions (e.g., "Make it technical and detailed")
|
||||
- After calling this tool, inform the user that generation has started and they will see the presentation when it's ready.
|
||||
"""
|
||||
|
||||
_TOOL_INSTRUCTIONS["generate_report"] = """
|
||||
- generate_report: Generate or revise a structured Markdown report artifact.
|
||||
- WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable:
|
||||
|
|
@ -271,6 +276,24 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
|
|||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
"""
|
||||
|
||||
_TOOL_INSTRUCTIONS["web_search"] = """
|
||||
- web_search: Search the web for real-time information using all configured search engines.
|
||||
- Use this for current events, news, prices, weather, public facts, or any question requiring
|
||||
up-to-date information from the internet.
|
||||
- This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in
|
||||
parallel and merges the results.
|
||||
- IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data
|
||||
(e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call
|
||||
`web_search` instead of answering from memory.
|
||||
- For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet
|
||||
access before attempting a web search.
|
||||
- If the search returns no relevant results, explain that web sources did not return enough
|
||||
data and ask the user if they want you to retry with a refined query.
|
||||
- Args:
|
||||
- query: The search query - use specific, descriptive terms
|
||||
- top_k: Number of results to retrieve (default: 10, max: 50)
|
||||
"""
|
||||
|
||||
# Memory tool instructions have private and shared variants.
|
||||
# We store them keyed as "save_memory" / "recall_memory" with sub-keys.
|
||||
_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = {
|
||||
|
|
@ -401,7 +424,7 @@ _TOOL_EXAMPLES["search_knowledge_base"] = """
|
|||
- User: "Check my Obsidian notes for meeting notes"
|
||||
- Call: `search_knowledge_base(query="meeting notes", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
||||
- User: "search me current usd to inr rate"
|
||||
- Call: `search_knowledge_base(query="current USD to INR exchange rate", connectors_to_search=["LINKUP_API", "TAVILY_API", "SEARXNG_API", "BAIDU_SEARCH_API"])`
|
||||
- Call: `web_search(query="current USD to INR exchange rate")`
|
||||
- Then answer using the returned live web results with citations.
|
||||
"""
|
||||
|
||||
|
|
@ -426,6 +449,16 @@ _TOOL_EXAMPLES["generate_podcast"] = """
|
|||
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
|
||||
"""
|
||||
|
||||
_TOOL_EXAMPLES["generate_video_presentation"] = """
|
||||
- User: "Give me a presentation about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")`
|
||||
- User: "Create slides summarizing this conversation"
|
||||
- Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
|
||||
- User: "Make a video presentation about quantum computing"
|
||||
- First search: `search_knowledge_base(query="quantum computing")`
|
||||
- Then: `generate_video_presentation(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", video_title="Quantum Computing Explained")`
|
||||
"""
|
||||
|
||||
_TOOL_EXAMPLES["generate_report"] = """
|
||||
- User: "Generate a report about AI trends"
|
||||
- Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")`
|
||||
|
|
@ -471,11 +504,23 @@ _TOOL_EXAMPLES["generate_image"] = """
|
|||
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
|
||||
"""
|
||||
|
||||
_TOOL_EXAMPLES["web_search"] = """
|
||||
- User: "What's the current USD to INR exchange rate?"
|
||||
- Call: `web_search(query="current USD to INR exchange rate")`
|
||||
- Then answer using the returned web results with citations.
|
||||
- User: "What's the latest news about AI?"
|
||||
- Call: `web_search(query="latest AI news today")`
|
||||
- User: "What's the weather in New York?"
|
||||
- Call: `web_search(query="weather New York today")`
|
||||
"""
|
||||
|
||||
# All tool names that have prompt instructions (order matters for prompt readability)
|
||||
_ALL_TOOL_NAMES_ORDERED = [
|
||||
"search_surfsense_docs",
|
||||
"search_knowledge_base",
|
||||
"web_search",
|
||||
"generate_podcast",
|
||||
"generate_video_presentation",
|
||||
"generate_report",
|
||||
"link_preview",
|
||||
"display_image",
|
||||
|
|
@ -543,7 +588,7 @@ DISABLED TOOLS (by user):
|
|||
The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}.
|
||||
You do NOT have access to these tools and MUST NOT claim you can use them.
|
||||
If the user asks about a capability provided by a disabled tool, let them know the relevant tool
|
||||
is currently disabled and they can re-enable it from the tools menu (wrench icon) in the composer toolbar.
|
||||
is currently disabled and they can re-enable it.
|
||||
""")
|
||||
|
||||
parts.append("\n</tools>\n")
|
||||
|
|
@ -595,11 +640,10 @@ The documents you receive are structured like this:
|
|||
</document_content>
|
||||
</document>
|
||||
|
||||
**Live web search results (URL chunk IDs):**
|
||||
**Web search results (URL chunk IDs):**
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_id>TAVILY_API::Some Title::https://example.com/article</document_id>
|
||||
<document_type>TAVILY_API</document_type>
|
||||
<document_type>WEB_SEARCH</document_type>
|
||||
<title><![CDATA[Some web search result]]></title>
|
||||
<url><![CDATA[https://example.com/article]]></url>
|
||||
</document_metadata>
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Available tools:
|
|||
- search_knowledge_base: Search the user's personal knowledge base
|
||||
- search_surfsense_docs: Search Surfsense documentation for usage help
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_video_presentation: Generate video presentations with slides and narration
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
- link_preview: Fetch rich previews for URLs
|
||||
- display_image: Display images in chat
|
||||
|
|
@ -39,6 +40,7 @@ from .registry import (
|
|||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
|
|
@ -51,6 +53,7 @@ __all__ = [
|
|||
"create_display_image_tool",
|
||||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_generate_video_presentation_tool",
|
||||
"create_link_preview_tool",
|
||||
"create_recall_memory_tool",
|
||||
"create_save_memory_tool",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
"""Confluence tools for creating, updating, and deleting pages."""
|
||||
|
||||
from .create_page import create_create_confluence_page_tool
|
||||
from .delete_page import create_delete_confluence_page_tool
|
||||
from .update_page import create_update_confluence_page_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_confluence_page_tool",
|
||||
"create_delete_confluence_page_tool",
|
||||
"create_update_confluence_page_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_confluence_page(
|
||||
title: str,
|
||||
content: str | None = None,
|
||||
space_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new page in Confluence.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new Confluence page.
|
||||
|
||||
Args:
|
||||
title: Title of the page.
|
||||
content: Optional HTML/storage format content for the page body.
|
||||
space_id: Optional Confluence space ID to create the page in.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, page_id, and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(f"create_confluence_page called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Confluence accounts need re-authentication.",
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_creation",
|
||||
"action": {
|
||||
"tool": "create_confluence_page",
|
||||
"params": {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_title = final_params.get("title", title)
|
||||
final_content = final_params.get("content", content) or ""
|
||||
final_space_id = final_params.get("space_id", space_id)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
return {"status": "error", "message": "Page title cannot be empty."}
|
||||
if not final_space_id:
|
||||
return {"status": "error", "message": "A space must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Confluence connector found.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
api_result = await client.create_page(
|
||||
space_id=final_space_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_id = str(api_result.get("id", ""))
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=page_id,
|
||||
page_title=final_title,
|
||||
space_id=final_space_id,
|
||||
body_content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the page.",
|
||||
}
|
||||
|
||||
return create_confluence_page
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_confluence_page(
|
||||
page_title_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Confluence page.
|
||||
|
||||
Use this tool when the user asks to delete or remove a Confluence page.
|
||||
|
||||
Args:
|
||||
page_title_or_id: The page title or ID to identify the page.
|
||||
delete_from_kb: Whether to also remove from the knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message, and deleted_from_kb.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
page_title = page_data.get("page_title", "")
|
||||
document_id = page_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_deletion",
|
||||
"action": {
|
||||
"tool": "delete_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not deleted.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
await client.delete_page(final_page_id)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Confluence page '{page_title}' deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the page.",
|
||||
}
|
||||
|
||||
return delete_confluence_page
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_confluence_page(
|
||||
page_title_or_id: str,
|
||||
new_title: str | None = None,
|
||||
new_content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Confluence page.
|
||||
|
||||
Use this tool when the user asks to modify or edit a Confluence page.
|
||||
|
||||
Args:
|
||||
page_title_or_id: The page title or ID to identify the page.
|
||||
new_title: Optional new title for the page.
|
||||
new_content: Optional new HTML/storage format content.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Confluence tool not properly configured.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = ConfluenceToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, page_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "confluence",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
page_data = context["page"]
|
||||
page_id = page_data["page_id"]
|
||||
current_title = page_data["page_title"]
|
||||
current_body = page_data.get("body", "")
|
||||
current_version = page_data.get("version", 1)
|
||||
document_id = page_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_update",
|
||||
"action": {
|
||||
"tool": "update_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not updated.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_title = final_params.get("new_title", new_title) or current_title
|
||||
final_content = final_params.get("new_content", new_content)
|
||||
if final_content is None:
|
||||
final_content = current_body
|
||||
final_version = final_params.get("version", current_version)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this page.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Confluence connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
client = ConfluenceHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
api_result = await client.update_page(
|
||||
page_id=final_page_id,
|
||||
title=final_title,
|
||||
body=final_content,
|
||||
version_number=final_version + 1,
|
||||
)
|
||||
await client.close()
|
||||
except Exception as api_err:
|
||||
if (
|
||||
"http 403" in str(api_err).lower()
|
||||
or "status code 403" in str(api_err).lower()
|
||||
):
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
page_links = (
|
||||
api_result.get("_links", {}) if isinstance(api_result, dict) else {}
|
||||
)
|
||||
page_url = ""
|
||||
if page_links.get("base") and page_links.get("webui"):
|
||||
page_url = f"{page_links['base']}{page_links['webui']}"
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.confluence import ConfluenceKBSyncService
|
||||
|
||||
kb_service = ConfluenceKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
page_id=final_page_id,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": final_page_id,
|
||||
"page_url": page_url,
|
||||
"message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error updating Confluence page: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the page.",
|
||||
}
|
||||
|
||||
return update_confluence_page
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||
create_create_gmail_draft_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.send_email import (
|
||||
create_send_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.trash_email import (
|
||||
create_trash_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.update_draft import (
|
||||
create_update_gmail_draft_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_gmail_draft_tool",
|
||||
"create_send_gmail_email_tool",
|
||||
"create_trash_gmail_email_tool",
|
||||
"create_update_gmail_draft_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_gmail_draft_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_gmail_draft(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a draft email in Gmail.
|
||||
|
||||
Use when the user asks to draft, compose, or prepare an email without
|
||||
sending it.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
subject: Email subject line.
|
||||
body: Email body content.
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", or "error"
|
||||
- draft_id: Gmail draft ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Draft an email to alice@example.com about the meeting"
|
||||
- "Compose a reply to Bob about the project update"
|
||||
"""
|
||||
logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_creation",
|
||||
"action": {
|
||||
"tool": "create_gmail_draft",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
draft_message = created.get("message", {})
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=draft_message.get("id", ""),
|
||||
thread_id=draft_message.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
draft_id=created.get("id"),
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": created.get("id"),
|
||||
"message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the draft. Please try again.",
|
||||
}
|
||||
|
||||
return create_gmail_draft
|
||||
343
surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
Normal file
343
surfsense_backend/app/agents/new_chat/tools/gmail/send_email.py
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_gmail_email(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send an email via Gmail.
|
||||
|
||||
Use when the user explicitly asks to send an email. This sends the
|
||||
email immediately - it cannot be unsent.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
subject: Email subject line.
|
||||
body: Email body content.
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", or "error"
|
||||
- message_id: Gmail message ID (if success)
|
||||
- thread_id: Gmail thread ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Send an email to alice@example.com about the meeting"
|
||||
- "Email Bob the project update"
|
||||
"""
|
||||
logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Gmail accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_send",
|
||||
"action": {
|
||||
"tool": "send_gmail_email",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.gmail import GmailKBSyncService
|
||||
|
||||
kb_service = GmailKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
message_id=sent.get("id", ""),
|
||||
thread_id=sent.get("threadId", ""),
|
||||
subject=final_subject,
|
||||
sender="me",
|
||||
date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
body_text=final_body,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after send failed: {kb_err}")
|
||||
kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": sent.get("id"),
|
||||
"thread_id": sent.get("threadId"),
|
||||
"message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error sending Gmail email: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while sending the email. Please try again.",
|
||||
}
|
||||
|
||||
return send_gmail_email
|
||||
337
surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
Normal file
337
surfsense_backend/app/agents/new_chat/tools/gmail/trash_email.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_trash_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def trash_gmail_email(
|
||||
email_subject_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Move an email or draft to trash in Gmail.
|
||||
|
||||
Use when the user asks to delete, remove, or trash an email or draft.
|
||||
|
||||
Args:
|
||||
email_subject_or_id: The exact subject line or message ID of the
|
||||
email to trash (as it appears in the inbox).
|
||||
delete_from_kb: Whether to also remove the email from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both Gmail and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- message_id: Gmail message ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the email subject or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry this tool.
|
||||
Examples:
|
||||
- "Delete the email about 'Meeting Cancelled'"
|
||||
- "Trash the email from Bob about the project"
|
||||
"""
|
||||
logger.info(
|
||||
f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, email_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Email not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not message_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_trash",
|
||||
"action": {
|
||||
"tool": "trash_gmail_email",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_message_id = final_params.get("message_id", message_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this email.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message_id": final_message_id,
|
||||
"message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"Email trashed, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the email. Please try again.",
|
||||
}
|
||||
|
||||
return trash_gmail_email
|
||||
|
|
@ -0,0 +1,438 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_gmail_draft_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_gmail_draft(
|
||||
draft_subject_or_id: str,
|
||||
body: str,
|
||||
to: str | None = None,
|
||||
subject: str | None = None,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Gmail draft.
|
||||
|
||||
Use when the user asks to modify, edit, or add content to an existing
|
||||
email draft. This replaces the draft content with the new version.
|
||||
The user will be able to review and edit the content before it is applied.
|
||||
|
||||
If the user simply wants to "edit" a draft without specifying exact changes,
|
||||
generate the body yourself using your best understanding of the conversation
|
||||
context. The user will review and can freely edit the content in the approval
|
||||
card before confirming.
|
||||
|
||||
IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
|
||||
deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
|
||||
calendar events, or any other content type.
|
||||
|
||||
Args:
|
||||
draft_subject_or_id: The exact subject line of the draft to update
|
||||
(as it appears in Gmail drafts).
|
||||
body: The full updated body content for the draft. Generate this
|
||||
yourself based on the user's request and conversation context.
|
||||
to: Optional new recipient email address (keeps original if omitted).
|
||||
subject: Optional new subject line (keeps original if omitted).
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- draft_id: Gmail draft ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the draft subject or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Update the Kurseong Plan draft with the new itinerary details"
|
||||
- "Edit my draft about the project proposal and change the recipient"
|
||||
- "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
|
||||
"""
|
||||
logger.info(
|
||||
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, draft_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Draft not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = account["id"]
|
||||
draft_id_from_context = context.get("draft_id")
|
||||
|
||||
original_subject = email.get("subject", draft_subject_or_id)
|
||||
final_subject_default = subject if subject else original_subject
|
||||
final_to_default = to if to else ""
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_update",
|
||||
"action": {
|
||||
"tool": "update_gmail_draft",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", final_to_default)
|
||||
final_subject = final_params.get("subject", final_subject_default)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = final_params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this draft.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
try:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
sa_select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
document.source_markdown = final_body
|
||||
document.title = final_subject
|
||||
meta = dict(document.document_metadata or {})
|
||||
meta["subject"] = final_subject
|
||||
meta["draft_id"] = updated.get("id", final_draft_id)
|
||||
updated_msg = updated.get("message", {})
|
||||
if updated_msg.get("id"):
|
||||
meta["message_id"] = updated_msg["id"]
|
||||
document.document_metadata = meta
|
||||
flag_modified(document, "document_metadata")
|
||||
await db_session.commit()
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
logger.info(
|
||||
f"KB document {document_id} updated for draft {final_draft_id}"
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB update after draft edit failed: {kb_err}")
|
||||
await db_session.rollback()
|
||||
kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"draft_id": updated.get("id"),
|
||||
"message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the draft. Please try again.",
|
||||
}
|
||||
|
||||
return update_gmail_draft
|
||||
|
||||
|
||||
async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
|
||||
"""Look up a draft's ID by its message ID via the Gmail API."""
|
||||
try:
|
||||
page_token = None
|
||||
while True:
|
||||
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda kwargs=kwargs: (
|
||||
gmail_service.users().drafts().list(**kwargs).execute()
|
||||
),
|
||||
)
|
||||
|
||||
for draft in response.get("drafts", []):
|
||||
if draft.get("message", {}).get("id") == message_id:
|
||||
return draft["id"]
|
||||
|
||||
page_token = response.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to look up draft by message_id: {e}")
|
||||
return None
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.google_calendar.create_event import (
|
||||
create_create_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
||||
create_delete_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_calendar_event_tool",
|
||||
"create_delete_calendar_event_tool",
|
||||
"create_update_calendar_event_tool",
|
||||
]
|
||||
|
|
@ -0,0 +1,352 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_calendar_event(
|
||||
summary: str,
|
||||
start_datetime: str,
|
||||
end_datetime: str,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new event on Google Calendar.
|
||||
|
||||
Use when the user asks to schedule, create, or add a calendar event.
|
||||
Ask for event details if not provided.
|
||||
|
||||
Args:
|
||||
summary: The event title.
|
||||
start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
|
||||
end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
|
||||
description: Optional event description.
|
||||
location: Optional event location.
|
||||
attendees: Optional list of attendee email addresses.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- html_link: URL to open the event (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Schedule a meeting with John tomorrow at 10am"
|
||||
- "Create a calendar event for the team standup"
|
||||
"""
|
||||
logger.info(
|
||||
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning(
|
||||
"All Google Calendar accounts have expired authentication"
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating calendar event: summary='{summary}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_creation",
|
||||
"action": {
|
||||
"tool": "create_calendar_event",
|
||||
"params": {
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_start_datetime = final_params.get("start_datetime", start_datetime)
|
||||
final_end_datetime = final_params.get("end_datetime", end_datetime)
|
||||
final_description = final_params.get("description", description)
|
||||
final_location = final_params.get("location", location)
|
||||
final_attendees = final_params.get("attendees", attendees)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Event summary cannot be empty."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
event_id=created.get("id"),
|
||||
event_summary=final_summary,
|
||||
calendar_id="primary",
|
||||
start_time=final_start_datetime,
|
||||
end_time=final_end_datetime,
|
||||
location=final_location,
|
||||
html_link=created.get("htmlLink"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": created.get("id"),
|
||||
"html_link": created.get("htmlLink"),
|
||||
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the event. Please try again.",
|
||||
}
|
||||
|
||||
return create_calendar_event
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_calendar_event(
|
||||
event_title_or_id: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Google Calendar event.
|
||||
|
||||
Use when the user asks to delete, remove, or cancel a calendar event.
|
||||
|
||||
Args:
|
||||
event_title_or_id: The exact title or event ID of the event to delete.
|
||||
delete_from_kb: Whether to also remove the event from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both Google Calendar and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the event name or check if it has been indexed.
|
||||
Examples:
|
||||
- "Delete the team standup event"
|
||||
- "Cancel my dentist appointment on Friday"
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch deletion context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Calendar account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_deletion",
|
||||
"action": {
|
||||
"tool": "delete_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
delete_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
delete_result["warning"] = (
|
||||
f"Event deleted, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
delete_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
delete_result["message"] = (
|
||||
f"{delete_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return delete_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error deleting calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the event. Please try again.",
|
||||
}
|
||||
|
||||
return delete_calendar_event
|
||||
|
|
@ -0,0 +1,382 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_calendar_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_calendar_event(
|
||||
event_title_or_id: str,
|
||||
new_summary: str | None = None,
|
||||
new_start_datetime: str | None = None,
|
||||
new_end_datetime: str | None = None,
|
||||
new_description: str | None = None,
|
||||
new_location: str | None = None,
|
||||
new_attendees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Google Calendar event.
|
||||
|
||||
Use when the user asks to modify, reschedule, or change a calendar event.
|
||||
|
||||
Args:
|
||||
event_title_or_id: The exact title or event ID of the event to update.
|
||||
new_summary: New event title (if changing).
|
||||
new_start_datetime: New start time in ISO 8601 format (if rescheduling).
|
||||
new_end_datetime: New end time in ISO 8601 format (if rescheduling).
|
||||
new_description: New event description (if changing).
|
||||
new_location: New event location (if changing).
|
||||
new_attendees: New list of attendee email addresses (if changing).
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", "auth_error", or "error"
|
||||
- event_id: Google Calendar event ID (if success)
|
||||
- html_link: URL to open the event (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the event name or check if it has been indexed.
|
||||
Examples:
|
||||
- "Reschedule the team standup to 3pm"
|
||||
- "Change the location of my dentist appointment"
|
||||
"""
|
||||
logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Calendar tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleCalendarToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, event_title_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Event not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if context.get("auth_expired"):
|
||||
logger.warning("Google Calendar account has expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_calendar",
|
||||
}
|
||||
|
||||
event = context["event"]
|
||||
event_id = event["event_id"]
|
||||
document_id = event.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not event_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_update",
|
||||
"action": {
|
||||
"tool": "update_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_new_summary = final_params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = final_params.get(
|
||||
"new_start_datetime", new_start_datetime
|
||||
)
|
||||
final_new_end_datetime = final_params.get(
|
||||
"new_end_datetime", new_end_datetime
|
||||
)
|
||||
final_new_description = final_params.get("new_description", new_description)
|
||||
final_new_location = final_params.get("new_location", new_location)
|
||||
final_new_attendees = final_params.get("new_attendees", new_attendees)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this event.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_calendar_types = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_calendar_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and app_config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(app_config.SECRET_KEY)
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if config_data.get(key):
|
||||
config_data[key] = token_encryption.decrypt_token(
|
||||
config_data[key]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
)
|
||||
update_body["start"] = {
|
||||
"dateTime": final_new_start_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_end_datetime is not None:
|
||||
tz = (
|
||||
context.get("timezone", "UTC")
|
||||
if isinstance(context, dict)
|
||||
else "UTC"
|
||||
)
|
||||
update_body["end"] = {
|
||||
"dateTime": final_new_end_datetime,
|
||||
"timeZone": tz,
|
||||
}
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=document_id,
|
||||
event_id=final_event_id,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": final_event_id,
|
||||
"html_link": updated.get("htmlLink"),
|
||||
"message": f"Successfully updated the calendar event.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error updating calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the event. Please try again.",
|
||||
}
|
||||
|
||||
return update_calendar_event
|
||||
|
|
@ -32,13 +32,16 @@ def create_create_google_drive_file_tool(
|
|||
"""Create a new Google Doc or Google Sheet in Google Drive.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new document
|
||||
or spreadsheet in Google Drive.
|
||||
or spreadsheet in Google Drive. The user MUST specify a topic before
|
||||
you call this tool. If the request does not contain a topic (e.g.
|
||||
"create a drive doc" or "make a Google Sheet"), ask what the file
|
||||
should be about. Never call this tool without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
name: The file name (without extension).
|
||||
file_type: Either "google_doc" or "google_sheet".
|
||||
content: Optional initial content. For google_doc, provide markdown text.
|
||||
For google_sheet, provide CSV-formatted text.
|
||||
content: Optional initial content. Generate from the user's topic.
|
||||
For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -55,8 +58,8 @@ def create_create_google_drive_file_tool(
|
|||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Create a Google Doc called 'Meeting Notes'"
|
||||
- "Create a spreadsheet named 'Budget 2026' with some sample data"
|
||||
- "Create a Google Doc with today's meeting notes"
|
||||
- "Create a spreadsheet for the 2026 budget"
|
||||
"""
|
||||
logger.info(
|
||||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||
|
|
@ -84,6 +87,15 @@ def create_create_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Google Drive accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
|
@ -154,14 +166,18 @@ def create_create_google_drive_file_tool(
|
|||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -176,8 +192,7 @@ def create_create_google_drive_file_tool(
|
|||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -191,8 +206,22 @@ def create_create_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
|
|
@ -206,22 +235,65 @@ def create_create_google_drive_file_tool(
|
|||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_drive import GoogleDriveKBSyncService
|
||||
|
||||
kb_service = GoogleDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=mime_type,
|
||||
web_view_link=created.get("webViewLink"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.",
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ def create_delete_google_drive_file_tool(
|
|||
to verify the file name or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry this tool.
|
||||
|
||||
Examples:
|
||||
- "Delete the 'Meeting Notes' file from Google Drive"
|
||||
- "Trash the 'Old Budget' spreadsheet"
|
||||
|
|
@ -76,6 +75,18 @@ def create_delete_google_drive_file_tool(
|
|||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Google Drive account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "google_drive",
|
||||
}
|
||||
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
|
|
@ -151,13 +162,17 @@ def create_delete_google_drive_file_tool(
|
|||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_drive_types = [
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnector.connector_type.in_(_drive_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
|
@ -170,7 +185,23 @@ def create_delete_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
client = GoogleDriveClient(session=db_session, connector_id=connector.id)
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
|
|
@ -178,10 +209,26 @@ def create_delete_google_drive_file_tool(
|
|||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
|
|
|
|||
11
surfsense_backend/app/agents/new_chat/tools/jira/__init__.py
Normal file
11
surfsense_backend/app/agents/new_chat/tools/jira/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Jira tools for creating, updating, and deleting issues."""
|
||||
|
||||
from .create_issue import create_create_jira_issue_tool
|
||||
from .delete_issue import create_delete_jira_issue_tool
|
||||
from .update_issue import create_update_jira_issue_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_jira_issue_tool",
|
||||
"create_delete_jira_issue_tool",
|
||||
"create_update_jira_issue_tool",
|
||||
]
|
||||
242
surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
Normal file
242
surfsense_backend/app/agents/new_chat/tools/jira/create_issue.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_jira_issue(
|
||||
project_key: str,
|
||||
summary: str,
|
||||
issue_type: str = "Task",
|
||||
description: str | None = None,
|
||||
priority: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new issue in Jira.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new Jira issue/ticket.
|
||||
|
||||
Args:
|
||||
project_key: The Jira project key (e.g. "PROJ", "ENG").
|
||||
summary: Short, descriptive issue title.
|
||||
issue_type: Issue type (default "Task"). Others: "Bug", "Story", "Epic".
|
||||
description: Optional description body for the issue.
|
||||
priority: Optional priority name (e.g. "High", "Medium", "Low").
|
||||
|
||||
Returns:
|
||||
Dictionary with status, issue_key, and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user declined. Do NOT retry.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"create_jira_issue called: project_key='{project_key}', summary='{summary}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Jira accounts need re-authentication.",
|
||||
"connector_type": "jira",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_creation",
|
||||
"action": {
|
||||
"tool": "create_jira_issue",
|
||||
"params": {
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not created.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_project_key = final_params.get("project_key", project_key)
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_issue_type = final_params.get("issue_type", issue_type)
|
||||
final_description = final_params.get("description", description)
|
||||
final_priority = final_params.get("priority", priority)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Issue summary cannot be empty."}
|
||||
if not final_project_key:
|
||||
return {"status": "error", "message": "A project must be selected."}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Jira connector found."}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
api_result = await asyncio.to_thread(
|
||||
jira_client.create_issue,
|
||||
project_key=final_project_key,
|
||||
summary=final_summary,
|
||||
issue_type=final_issue_type,
|
||||
description=final_description,
|
||||
priority=final_priority,
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
_conn = connector
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_key = api_result.get("key", "")
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{issue_key}"
|
||||
if jira_history._base_url and issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=issue_key,
|
||||
issue_identifier=issue_key,
|
||||
issue_title=final_summary,
|
||||
description=final_description,
|
||||
state="To Do",
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error creating Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the issue.",
|
||||
}
|
||||
|
||||
return create_jira_issue
|
||||
209
surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
Normal file
209
surfsense_backend/app/agents/new_chat/tools/jira/delete_issue.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a Jira issue.
|
||||
|
||||
Use this tool when the user asks to delete or remove a Jira issue.
|
||||
|
||||
Args:
|
||||
issue_title_or_key: The issue key (e.g. "PROJ-42") or title.
|
||||
delete_from_kb: Whether to also remove from the knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message, and deleted_from_kb.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message to the user.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_deletion_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_deletion",
|
||||
"action": {
|
||||
"tool": "delete_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not deleted.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
|
||||
message = f"Jira issue {final_issue_key} deleted successfully."
|
||||
if deleted_from_kb:
|
||||
message += " Also removed from the knowledge base."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"deleted_from_kb": deleted_from_kb,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error deleting Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while deleting the issue.",
|
||||
}
|
||||
|
||||
return delete_jira_issue
|
||||
252
surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
Normal file
252
surfsense_backend/app/agents/new_chat/tools/jira/update_issue.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_jira_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_jira_issue(
|
||||
issue_title_or_key: str,
|
||||
new_summary: str | None = None,
|
||||
new_description: str | None = None,
|
||||
new_priority: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Jira issue.
|
||||
|
||||
Use this tool when the user asks to modify, edit, or update a Jira issue.
|
||||
|
||||
Args:
|
||||
issue_title_or_key: The issue key (e.g. "PROJ-42") or title to identify the issue.
|
||||
new_summary: Optional new title/summary for the issue.
|
||||
new_description: Optional new description.
|
||||
new_priority: Optional new priority name.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and message.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", do NOT retry.
|
||||
- If status is "not_found", relay the message and ask user to verify.
|
||||
- If status is "insufficient_permissions", inform user to re-authenticate.
|
||||
"""
|
||||
logger.info(
|
||||
f"update_jira_issue called: issue_title_or_key='{issue_title_or_key}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Jira tool not properly configured."}
|
||||
|
||||
try:
|
||||
metadata_service = JiraToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, issue_title_or_key
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "jira",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_data = context["issue"]
|
||||
issue_key = issue_data["issue_id"]
|
||||
document_id = issue_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_update",
|
||||
"action": {
|
||||
"tool": "update_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not updated.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_summary = final_params.get("new_summary", new_summary)
|
||||
final_description = final_params.get("new_description", new_description)
|
||||
final_priority = final_params.get("new_priority", new_priority)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Jira connector is invalid.",
|
||||
}
|
||||
|
||||
fields: dict[str, Any] = {}
|
||||
if final_summary:
|
||||
fields["summary"] = final_summary
|
||||
if final_description is not None:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [{"type": "text", "text": final_description}],
|
||||
}
|
||||
],
|
||||
}
|
||||
if final_priority:
|
||||
fields["priority"] = {"name": final_priority}
|
||||
|
||||
if not fields:
|
||||
return {"status": "error", "message": "No changes specified."}
|
||||
|
||||
try:
|
||||
jira_history = JiraHistoryConnector(
|
||||
session=db_session, connector_id=final_connector_id
|
||||
)
|
||||
jira_client = await jira_history._get_jira_client()
|
||||
await asyncio.to_thread(
|
||||
jira_client.update_issue, final_issue_key, fields
|
||||
)
|
||||
except Exception as api_err:
|
||||
if "status code 403" in str(api_err).lower():
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": final_connector_id,
|
||||
"message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
issue_url = (
|
||||
f"{jira_history._base_url}/browse/{final_issue_key}"
|
||||
if jira_history._base_url and final_issue_key
|
||||
else ""
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if final_document_id:
|
||||
try:
|
||||
from app.services.jira import JiraKBSyncService
|
||||
|
||||
kb_service = JiraKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_update(
|
||||
document_id=final_document_id,
|
||||
issue_id=final_issue_key,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after update failed: {kb_err}")
|
||||
kb_message_suffix = (
|
||||
" The knowledge base will be updated in the next sync."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_key": final_issue_key,
|
||||
"issue_url": issue_url,
|
||||
"message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error(f"Error updating Jira issue: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while updating the issue.",
|
||||
}
|
||||
|
||||
return update_jira_issue
|
||||
|
|
@ -9,6 +9,7 @@ This module provides:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
|
@ -19,15 +20,14 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Connectors that call external live-search APIs (no local DB / embedding needed).
|
||||
# These are never filtered by available_document_types.
|
||||
# Connectors that call external live-search APIs. These are handled by the
|
||||
# ``web_search`` tool and must be excluded from knowledge-base searches.
|
||||
_LIVE_SEARCH_CONNECTORS: set[str] = {
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
|
@ -61,7 +61,7 @@ def _is_degenerate_query(query: str) -> bool:
|
|||
|
||||
async def _browse_recent_documents(
|
||||
search_space_id: int,
|
||||
document_type: str | None,
|
||||
document_type: str | list[str] | None,
|
||||
top_k: int,
|
||||
start_date: datetime | None,
|
||||
end_date: datetime | None,
|
||||
|
|
@ -84,14 +84,22 @@ async def _browse_recent_documents(
|
|||
base_conditions = [Document.search_space_id == search_space_id]
|
||||
|
||||
if document_type is not None:
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
if start_date is not None:
|
||||
base_conditions.append(Document.updated_at >= start_date)
|
||||
|
|
@ -190,20 +198,12 @@ _ALL_CONNECTORS: list[str] = [
|
|||
"GOOGLE_DRIVE_FILE",
|
||||
"DISCORD_CONNECTOR",
|
||||
"AIRTABLE_CONNECTOR",
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
"LUMA_CONNECTOR",
|
||||
"NOTE",
|
||||
"BOOKSTACK_CONNECTOR",
|
||||
"CRAWLED_URL",
|
||||
"CIRCLEBACK",
|
||||
"OBSIDIAN_CONNECTOR",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"COMPOSIO_GMAIL_CONNECTOR",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
]
|
||||
|
||||
# Human-readable descriptions for each connector type
|
||||
|
|
@ -227,20 +227,12 @@ CONNECTOR_DESCRIPTIONS: dict[str, str] = {
|
|||
"GOOGLE_DRIVE_FILE": "Google Drive files and documents (personal cloud storage)",
|
||||
"DISCORD_CONNECTOR": "Discord server conversations and shared content (personal community)",
|
||||
"AIRTABLE_CONNECTOR": "Airtable records, tables, and database content (personal data)",
|
||||
"TAVILY_API": "Tavily web search API results (real-time web search)",
|
||||
"SEARXNG_API": "SearxNG search API results (privacy-focused web search)",
|
||||
"LINKUP_API": "Linkup search API results (web search)",
|
||||
"BAIDU_SEARCH_API": "Baidu search API results (Chinese web search)",
|
||||
"LUMA_CONNECTOR": "Luma events and meetings",
|
||||
"WEBCRAWLER_CONNECTOR": "Webpages indexed by SurfSense (personally selected websites)",
|
||||
"CRAWLED_URL": "Webpages indexed by SurfSense (personally selected websites)",
|
||||
"BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
|
||||
"CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
|
||||
"OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "Google Drive files via Composio (personal cloud storage)",
|
||||
"COMPOSIO_GMAIL_CONNECTOR": "Gmail emails via Composio (personal emails)",
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events via Composio (personal calendar)",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -268,14 +260,15 @@ def _normalize_connectors(
|
|||
valid_set = (
|
||||
set(available_connectors) if available_connectors else set(_ALL_CONNECTORS)
|
||||
)
|
||||
valid_set -= _LIVE_SEARCH_CONNECTORS
|
||||
|
||||
if not connectors_to_search:
|
||||
# Search all available connectors if none specified
|
||||
return (
|
||||
base = (
|
||||
list(available_connectors)
|
||||
if available_connectors
|
||||
else list(_ALL_CONNECTORS)
|
||||
)
|
||||
return [c for c in base if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
|
||||
normalized: list[str] = []
|
||||
for raw in connectors_to_search:
|
||||
|
|
@ -302,15 +295,14 @@ def _normalize_connectors(
|
|||
out.append(c)
|
||||
|
||||
# Fallback to all available if nothing matched
|
||||
return (
|
||||
out
|
||||
if out
|
||||
else (
|
||||
if not out:
|
||||
base = (
|
||||
list(available_connectors)
|
||||
if available_connectors
|
||||
else list(_ALL_CONNECTORS)
|
||||
)
|
||||
)
|
||||
return [c for c in base if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
return out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -361,6 +353,20 @@ def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
|||
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||
|
||||
|
||||
_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"message_id",
|
||||
"thread_id",
|
||||
"event_id",
|
||||
"calendar_id",
|
||||
"google_drive_file_id",
|
||||
"page_id",
|
||||
"issue_id",
|
||||
"connector_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def format_documents_for_context(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
|
|
@ -479,7 +485,6 @@ def format_documents_for_context(
|
|||
# a numeric chunk_id (the numeric IDs are meaningless auto-incremented counters).
|
||||
live_search_connectors = {
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
|
@ -490,7 +495,10 @@ def format_documents_for_context(
|
|||
total_docs = len(grouped)
|
||||
|
||||
for doc_idx, g in enumerate(grouped.values()):
|
||||
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
||||
metadata_clean = {
|
||||
k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
|
||||
}
|
||||
metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
|
||||
is_live_search = g["document_type"] in live_search_connectors
|
||||
|
||||
doc_lines: list[str] = [
|
||||
|
|
@ -623,12 +631,15 @@ async def search_knowledge_base_async(
|
|||
|
||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
|
||||
# --- Optimization 1: skip local connectors that have zero indexed documents ---
|
||||
# --- Optimization 1: skip connectors that have zero indexed documents ---
|
||||
if available_document_types:
|
||||
doc_types_set = set(available_document_types)
|
||||
before_count = len(connectors)
|
||||
connectors = [
|
||||
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
|
||||
c
|
||||
for c in connectors
|
||||
if c in doc_types_set
|
||||
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
|
||||
]
|
||||
skipped = before_count - len(connectors)
|
||||
if skipped:
|
||||
|
|
@ -664,9 +675,14 @@ async def search_knowledge_base_async(
|
|||
"[kb_search] degenerate query %r detected - falling back to recency browse",
|
||||
query,
|
||||
)
|
||||
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
if not local_connectors:
|
||||
local_connectors = [None] # type: ignore[list-item]
|
||||
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
|
||||
|
||||
expanded_browse = []
|
||||
for c in browse_connectors:
|
||||
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
|
||||
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
|
||||
else:
|
||||
expanded_browse.append(c)
|
||||
|
||||
browse_results = await asyncio.gather(
|
||||
*[
|
||||
|
|
@ -677,7 +693,7 @@ async def search_knowledge_base_async(
|
|||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
for c in local_connectors
|
||||
for c in expanded_browse
|
||||
]
|
||||
)
|
||||
for docs in browse_results:
|
||||
|
|
@ -702,66 +718,20 @@ async def search_knowledge_base_async(
|
|||
)
|
||||
return result
|
||||
|
||||
# Specs for live-search connectors (external APIs, no local DB/embedding).
|
||||
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"SEARXNG_API": ("search_searxng", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
}
|
||||
|
||||
# --- Optimization 2: compute the query embedding once, share across all local searches ---
|
||||
precomputed_embedding: list[float] | None = None
|
||||
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
|
||||
if has_local_connectors:
|
||||
from app.config import config as app_config
|
||||
from app.config import config as app_config
|
||||
|
||||
t_embed = time.perf_counter()
|
||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
||||
perf.info(
|
||||
"[kb_search] shared embedding computed in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
t_embed = time.perf_counter()
|
||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
||||
perf.info(
|
||||
"[kb_search] shared embedding computed in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
max_parallel_searches = 4
|
||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||
|
||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||
is_live = connector in _LIVE_SEARCH_CONNECTORS
|
||||
|
||||
if is_live:
|
||||
spec = live_connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
|
|
@ -839,6 +809,10 @@ async def search_knowledge_base_async(
|
|||
|
||||
deduplicated.append(doc)
|
||||
|
||||
# Sort by RRF score so the most relevant documents from ANY connector
|
||||
# appear first, preventing budget truncation from hiding top results.
|
||||
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
|
||||
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
|
|
@ -967,7 +941,9 @@ Focus searches on these types for best results."""
|
|||
# This is what the LLM sees when deciding whether/how to use the tool
|
||||
dynamic_description = f"""Search the user's personal knowledge base for relevant information.
|
||||
|
||||
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
|
||||
Use this tool to find documents, notes, files, web pages, and other content the user has indexed.
|
||||
This searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
|
||||
For real-time web search (current events, news, live data), use the `web_search` tool instead.
|
||||
|
||||
IMPORTANT:
|
||||
- Always craft specific, descriptive search queries using natural language keywords.
|
||||
|
|
@ -977,9 +953,6 @@ IMPORTANT:
|
|||
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
|
||||
- If `connectors_to_search` is omitted/empty, the system will search broadly.
|
||||
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
|
||||
- For real-time/public web queries (e.g., current exchange rates, stock prices, breaking news, weather),
|
||||
explicitly include live web connectors in `connectors_to_search`, prioritizing:
|
||||
["LINKUP_API", "TAVILY_API", "SEARXNG_API", "BAIDU_SEARCH_API"].
|
||||
|
||||
## Available connector enums for `connectors_to_search`
|
||||
|
||||
|
|
|
|||
|
|
@ -38,11 +38,13 @@ def create_create_linear_issue_tool(
|
|||
"""Create a new issue in Linear.
|
||||
|
||||
Use this tool when the user explicitly asks to create, add, or file
|
||||
a new issue / ticket / task in Linear.
|
||||
a new issue / ticket / task in Linear. The user MUST describe the issue
|
||||
before you call this tool. If the request is vague, ask what the issue
|
||||
should be about. Never call this tool without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
title: Short, descriptive issue title.
|
||||
description: Optional markdown body for the issue.
|
||||
title: Short, descriptive issue title. Infer from the user's request.
|
||||
description: Optional markdown body for the issue. Generate from context.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -57,9 +59,9 @@ def create_create_linear_issue_tool(
|
|||
and move on. Do NOT retry, troubleshoot, or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Create a Linear issue titled 'Fix login bug'"
|
||||
- "Add a ticket for the payment timeout problem"
|
||||
- "File an issue about the broken search feature"
|
||||
- "Create a Linear issue for the login bug"
|
||||
- "File a ticket about the payment timeout problem"
|
||||
- "Add an issue for the broken search feature"
|
||||
"""
|
||||
logger.info(f"create_linear_issue called: title='{title}'")
|
||||
|
||||
|
|
@ -82,6 +84,15 @@ def create_create_linear_issue_tool(
|
|||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
workspaces = context.get("workspaces", [])
|
||||
if workspaces and all(w.get("auth_expired") for w in workspaces):
|
||||
logger.warning("All Linear accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "linear",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
|
|
@ -215,12 +226,36 @@ def create_create_linear_issue_tool(
|
|||
logger.info(
|
||||
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.linear import LinearKBSyncService
|
||||
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=result.get("id"),
|
||||
issue_identifier=result.get("identifier", ""),
|
||||
issue_title=result.get("title", final_title),
|
||||
issue_url=result.get("url"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_id": result.get("id"),
|
||||
"identifier": result.get("identifier"),
|
||||
"url": result.get("url"),
|
||||
"message": result.get("message"),
|
||||
"message": (result.get("message", "") + kb_message_suffix),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ def create_delete_linear_issue_tool(
|
|||
- If status is "not_found", inform the user conversationally using the exact message
|
||||
provided. Do NOT treat this as an error. Simply relay the message and ask the user
|
||||
to verify the issue title or identifier, or check if it has been indexed.
|
||||
|
||||
Examples:
|
||||
- "Delete the 'Fix login bug' Linear issue"
|
||||
- "Archive ENG-42"
|
||||
|
|
@ -91,6 +90,14 @@ def create_delete_linear_issue_tool(
|
|||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for delete context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,14 @@ def create_update_linear_issue_tool(
|
|||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for update context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
|
|
|
|||
|
|
@ -33,17 +33,21 @@ def create_create_notion_page_tool(
|
|||
@tool
|
||||
async def create_notion_page(
|
||||
title: str,
|
||||
content: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new page in Notion with the given title and content.
|
||||
|
||||
Use this tool when the user asks you to create, save, or publish
|
||||
something to Notion. The page will be created in the user's
|
||||
configured Notion workspace.
|
||||
configured Notion workspace. The user MUST specify a topic before you
|
||||
call this tool. If the request does not contain a topic (e.g. "create a
|
||||
notion page"), ask what the page should be about. Never call this tool
|
||||
without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
title: The title of the Notion page.
|
||||
content: The markdown content for the page body (supports headings, lists, paragraphs).
|
||||
content: Optional markdown content for the page body (supports headings, lists, paragraphs).
|
||||
Generate this yourself based on the user's topic.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -58,8 +62,8 @@ def create_create_notion_page_tool(
|
|||
and move on. Do NOT troubleshoot or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Create a Notion page titled 'Meeting Notes' with content 'Discussed project timeline'"
|
||||
- "Save this to Notion with title 'Research Summary'"
|
||||
- "Create a Notion page about our Q2 roadmap"
|
||||
- "Save a summary of today's discussion to Notion"
|
||||
"""
|
||||
logger.info(f"create_notion_page called: title='{title}'")
|
||||
|
||||
|
|
@ -85,6 +89,15 @@ def create_create_notion_page_tool(
|
|||
"message": context["error"],
|
||||
}
|
||||
|
||||
accounts = context.get("accounts", [])
|
||||
if accounts and all(a.get("auth_expired") for a in accounts):
|
||||
logger.warning("All Notion accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "notion",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Notion page: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
|
|
@ -215,6 +228,34 @@ def create_create_notion_page_tool(
|
|||
logger.info(
|
||||
f"create_page result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.notion import NotionKBSyncService
|
||||
|
||||
kb_service = NotionKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
page_id=result.get("page_id"),
|
||||
page_title=result.get("title", final_title),
|
||||
page_url=result.get("url"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = (
|
||||
" Your knowledge base has also been updated."
|
||||
)
|
||||
else:
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
result["message"] = result.get("message", "") + kb_message_suffix
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -95,8 +95,19 @@ def create_delete_notion_page_tool(
|
|||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
connector_id_from_context = account.get("id")
|
||||
document_id = context.get("document_id")
|
||||
|
||||
logger.info(
|
||||
|
|
@ -262,6 +273,18 @@ def create_delete_notion_page_tool(
|
|||
raise
|
||||
|
||||
logger.error(f"Error deleting Notion page: {e}", exc_info=True)
|
||||
error_str = str(e).lower()
|
||||
if isinstance(e, NotionAPIError) and (
|
||||
"401" in error_str or "unauthorized" in error_str
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": str(e),
|
||||
"connector_id": connector_id_from_context
|
||||
if "connector_id_from_context" in dir()
|
||||
else None,
|
||||
"connector_type": "notion",
|
||||
}
|
||||
if isinstance(e, ValueError | NotionAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -33,16 +33,19 @@ def create_update_notion_page_tool(
|
|||
@tool
|
||||
async def update_notion_page(
|
||||
page_title: str,
|
||||
content: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Notion page by appending new content.
|
||||
|
||||
Use this tool when the user asks you to add content to, modify, or update
|
||||
a Notion page. The new content will be appended to the existing page content.
|
||||
The user MUST specify what to add before you call this tool. If the
|
||||
request is vague, ask what content they want added.
|
||||
|
||||
Args:
|
||||
page_title: The title of the Notion page to update.
|
||||
content: The markdown content to append to the page body (supports headings, lists, paragraphs).
|
||||
content: Optional markdown content to append to the page body (supports headings, lists, paragraphs).
|
||||
Generate this yourself based on the user's request.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
|
|
@ -60,10 +63,9 @@ def create_update_notion_page_tool(
|
|||
Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]"
|
||||
Do NOT treat this as an error. Do NOT invent information. Simply relay the message and
|
||||
ask the user to verify the page title or check if it's been indexed.
|
||||
|
||||
Examples:
|
||||
- "Add 'New meeting notes from today' to the 'Meeting Notes' Notion page"
|
||||
- "Append the following to the 'Project Plan' Notion page: '# Status Update\n\nCompleted phase 1'"
|
||||
- "Add today's meeting notes to the 'Meeting Notes' Notion page"
|
||||
- "Update the 'Project Plan' page with a status update on phase 1"
|
||||
"""
|
||||
logger.info(
|
||||
f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
|
||||
|
|
@ -107,6 +109,17 @@ def create_update_notion_page_tool(
|
|||
"message": error_msg,
|
||||
}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Notion account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
}
|
||||
|
||||
page_id = context.get("page_id")
|
||||
document_id = context.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
|
@ -261,6 +274,18 @@ def create_update_notion_page_tool(
|
|||
raise
|
||||
|
||||
logger.error(f"Error updating Notion page: {e}", exc_info=True)
|
||||
error_str = str(e).lower()
|
||||
if isinstance(e, NotionAPIError) and (
|
||||
"401" in error_str or "unauthorized" in error_str
|
||||
):
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": str(e),
|
||||
"connector_id": connector_id_from_context
|
||||
if "connector_id_from_context" in dir()
|
||||
else None,
|
||||
"connector_type": "notion",
|
||||
}
|
||||
if isinstance(e, ValueError | NotionAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,60 +4,15 @@ Podcast generation tool for the SurfSense agent.
|
|||
This module provides a factory function for creating the generate_podcast tool
|
||||
that submits a Celery task for background podcast generation. The frontend
|
||||
polls for completion and auto-updates when the podcast is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one podcast can be generated at a time per search space
|
||||
- Uses Redis to track active podcast tasks
|
||||
- Returns a friendly message if a podcast is already being generated
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
|
||||
# Redis connection for tracking active podcast tasks
|
||||
# Defaults to the Celery broker when REDIS_APP_URL is not set
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for podcast task tracking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _redis_key(search_space_id: int) -> str:
|
||||
return f"podcast:generating:{search_space_id}"
|
||||
|
||||
|
||||
def get_generating_podcast_id(search_space_id: int) -> int | None:
|
||||
"""Get the podcast ID currently being generated for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
value = client.get(_redis_key(search_space_id))
|
||||
return int(value) if value else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def set_generating_podcast(search_space_id: int, podcast_id: int) -> None:
|
||||
"""Mark a podcast as currently generating for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.setex(_redis_key(search_space_id), 1800, str(podcast_id))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[generate_podcast] Warning: Could not set generating podcast in Redis: {e}"
|
||||
)
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -109,18 +64,6 @@ def create_generate_podcast_tool(
|
|||
- message: Status message (or "error" field if status is failed)
|
||||
"""
|
||||
try:
|
||||
generating_podcast_id = get_generating_podcast_id(search_space_id)
|
||||
if generating_podcast_id:
|
||||
print(
|
||||
f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
|
||||
)
|
||||
return {
|
||||
"status": PodcastStatus.GENERATING.value,
|
||||
"podcast_id": generating_podcast_id,
|
||||
"title": podcast_title,
|
||||
"message": "A podcast is already being generated. Please wait for it to complete.",
|
||||
}
|
||||
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
|
|
@ -142,8 +85,6 @@ def create_generate_podcast_tool(
|
|||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
set_generating_podcast(search_space_id, podcast.id)
|
||||
|
||||
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -45,12 +45,33 @@ from langchain_core.tools import BaseTool
|
|||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .confluence import (
|
||||
create_create_confluence_page_tool,
|
||||
create_delete_confluence_page_tool,
|
||||
create_update_confluence_page_tool,
|
||||
)
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .gmail import (
|
||||
create_create_gmail_draft_tool,
|
||||
create_send_gmail_email_tool,
|
||||
create_trash_gmail_email_tool,
|
||||
create_update_gmail_draft_tool,
|
||||
)
|
||||
from .google_calendar import (
|
||||
create_create_calendar_event_tool,
|
||||
create_delete_calendar_event_tool,
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .jira import (
|
||||
create_create_jira_issue_tool,
|
||||
create_delete_jira_issue_tool,
|
||||
create_update_jira_issue_tool,
|
||||
)
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
|
|
@ -73,6 +94,8 @@ from .shared_memory import (
|
|||
create_save_shared_memory_tool,
|
||||
)
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
from .web_search import create_web_search_tool
|
||||
|
||||
# =============================================================================
|
||||
# Tool Definition
|
||||
|
|
@ -135,6 +158,17 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["search_space_id", "db_session", "thread_id"],
|
||||
),
|
||||
# Video presentation generation tool
|
||||
ToolDefinition(
|
||||
name="generate_video_presentation",
|
||||
description="Generate a video presentation with slides and narration from provided content",
|
||||
factory=lambda deps: create_generate_video_presentation_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
thread_id=deps["thread_id"],
|
||||
),
|
||||
requires=["search_space_id", "db_session", "thread_id"],
|
||||
),
|
||||
# Report generation tool (inline, short-lived sessions for DB ops)
|
||||
# Supports internal KB search via source_strategy so the agent doesn't
|
||||
# need to call search_knowledge_base separately before generating.
|
||||
|
|
@ -186,7 +220,16 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=[], # firecrawl_api_key is optional
|
||||
),
|
||||
# Note: write_todos is now provided by TodoListMiddleware from deepagents
|
||||
# Web search tool — real-time web search via SearXNG + user-configured engines
|
||||
ToolDefinition(
|
||||
name="web_search",
|
||||
description="Search the web for real-time information using configured search engines",
|
||||
factory=lambda deps: create_web_search_tool(
|
||||
search_space_id=deps.get("search_space_id"),
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
),
|
||||
requires=[],
|
||||
),
|
||||
# Surfsense documentation search tool
|
||||
ToolDefinition(
|
||||
name="search_surfsense_docs",
|
||||
|
|
@ -235,7 +278,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues (WIP - hidden from UI)
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_linear_issue",
|
||||
|
|
@ -246,8 +290,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_linear_issue",
|
||||
|
|
@ -258,8 +300,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_linear_issue",
|
||||
|
|
@ -270,11 +310,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# NOTION TOOLS - create, update, delete pages (WIP - hidden from UI)
|
||||
# NOTION TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_notion_page",
|
||||
|
|
@ -285,8 +324,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_notion_page",
|
||||
|
|
@ -297,8 +334,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_notion_page",
|
||||
|
|
@ -309,11 +344,10 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files (WIP - hidden from UI)
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
# Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_google_drive_file",
|
||||
|
|
@ -324,8 +358,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
|
|
@ -336,8 +368,152 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
enabled_by_default=False,
|
||||
hidden=True,
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
||||
# Auto-disabled when no Google Calendar connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_calendar_event",
|
||||
description="Create a new event on Google Calendar",
|
||||
factory=lambda deps: create_create_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_calendar_event",
|
||||
description="Update an existing indexed Google Calendar event",
|
||||
factory=lambda deps: create_update_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_calendar_event",
|
||||
description="Delete an existing indexed Google Calendar event",
|
||||
factory=lambda deps: create_delete_calendar_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
|
||||
# Auto-disabled when no Gmail connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_gmail_draft",
|
||||
description="Create a draft email in Gmail",
|
||||
factory=lambda deps: create_create_gmail_draft_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
description="Send an email via Gmail",
|
||||
factory=lambda deps: create_send_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="trash_gmail_email",
|
||||
description="Move an indexed email to trash in Gmail",
|
||||
factory=lambda deps: create_trash_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_gmail_draft",
|
||||
description="Update an existing Gmail draft",
|
||||
factory=lambda deps: create_update_gmail_draft_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# JIRA TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_jira_issue",
|
||||
description="Create a new issue in the user's Jira project",
|
||||
factory=lambda deps: create_create_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_jira_issue",
|
||||
description="Update an existing indexed Jira issue",
|
||||
factory=lambda deps: create_update_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_jira_issue",
|
||||
description="Delete an existing indexed Jira issue",
|
||||
factory=lambda deps: create_delete_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_confluence_page",
|
||||
description="Create a new page in the user's Confluence space",
|
||||
factory=lambda deps: create_create_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_confluence_page",
|
||||
description="Update an existing indexed Confluence page",
|
||||
factory=lambda deps: create_update_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_confluence_page",
|
||||
description="Delete an existing indexed Confluence page",
|
||||
factory=lambda deps: create_delete_confluence_page_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
"""
|
||||
Video presentation generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a factory function for creating the generate_video_presentation
|
||||
tool that submits a Celery task for background video presentation generation.
|
||||
The frontend polls for completion and auto-updates when the presentation is ready.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
|
||||
|
||||
def create_generate_video_presentation_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
thread_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_video_presentation tool with injected dependencies.
|
||||
|
||||
Pre-creates video presentation record with pending status so the ID is available
|
||||
immediately for frontend polling.
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def generate_video_presentation(
|
||||
source_content: str,
|
||||
video_title: str = "SurfSense Presentation",
|
||||
user_prompt: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate a video presentation from the provided content.
|
||||
|
||||
Use this tool when the user asks to create a video, presentation, slides, or slide deck.
|
||||
|
||||
Args:
|
||||
source_content: The text content to turn into a presentation.
|
||||
video_title: Title for the presentation (default: "SurfSense Presentation")
|
||||
user_prompt: Optional style/tone instructions.
|
||||
"""
|
||||
try:
|
||||
video_pres = VideoPresentation(
|
||||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
db_session.add(video_pres)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(video_pres)
|
||||
|
||||
from app.tasks.celery_tasks.video_presentation_tasks import (
|
||||
generate_video_presentation_task,
|
||||
)
|
||||
|
||||
task = generate_video_presentation_task.delay(
|
||||
video_presentation_id=video_pres.id,
|
||||
source_content=source_content,
|
||||
search_space_id=search_space_id,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": VideoPresentationStatus.PENDING.value,
|
||||
"video_presentation_id": video_pres.id,
|
||||
"title": video_title,
|
||||
"message": "Video presentation generation started. This may take a few minutes.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
print(f"[generate_video_presentation] Error: {error_message}")
|
||||
return {
|
||||
"status": VideoPresentationStatus.FAILED.value,
|
||||
"error": error_message,
|
||||
"title": video_title,
|
||||
"video_presentation_id": None,
|
||||
}
|
||||
|
||||
return generate_video_presentation
|
||||
247
surfsense_backend/app/agents/new_chat/tools/web_search.py
Normal file
247
surfsense_backend/app/agents/new_chat/tools/web_search.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Web search tool for the SurfSense agent.
|
||||
|
||||
Provides a unified tool for real-time web searches that dispatches to all
|
||||
configured search engines: the platform SearXNG instance (always available)
|
||||
plus any user-configured live-search connectors (Tavily, Linkup, Baidu).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_LIVE_SEARCH_CONNECTORS: set[str] = {
|
||||
"TAVILY_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
_LIVE_CONNECTOR_SPECS: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
}
|
||||
|
||||
_CONNECTOR_LABELS: dict[str, str] = {
|
||||
"TAVILY_API": "Tavily",
|
||||
"LINKUP_API": "Linkup",
|
||||
"BAIDU_SEARCH_API": "Baidu",
|
||||
}
|
||||
|
||||
|
||||
class WebSearchInput(BaseModel):
|
||||
"""Input schema for the web_search tool."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query to look up on the web. Use specific, descriptive terms.",
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=10,
|
||||
description="Number of results to retrieve (default: 10, max: 50).",
|
||||
)
|
||||
|
||||
|
||||
def _format_web_results(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
max_chars: int = 50_000,
|
||||
) -> str:
|
||||
"""Format web search results into XML suitable for the LLM context."""
|
||||
if not documents:
|
||||
return "No web search results found."
|
||||
|
||||
parts: list[str] = []
|
||||
total_chars = 0
|
||||
|
||||
for doc in documents:
|
||||
doc_info = doc.get("document") or {}
|
||||
metadata = doc_info.get("metadata") or {}
|
||||
title = doc_info.get("title") or "Web Result"
|
||||
url = metadata.get("url") or ""
|
||||
content = (doc.get("content") or "").strip()
|
||||
source = metadata.get("document_type") or doc.get("source") or "WEB_SEARCH"
|
||||
if not content:
|
||||
continue
|
||||
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
doc_xml = "\n".join(
|
||||
[
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_type>{source}</document_type>",
|
||||
f" <title><![CDATA[{title}]]></title>",
|
||||
f" <url><![CDATA[{url}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"<document_content>",
|
||||
f" <chunk id='{url}'><![CDATA[{content}]]></chunk>",
|
||||
"</document_content>",
|
||||
"</document>",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
if total_chars + len(doc_xml) > max_chars:
|
||||
parts.append("<!-- Output truncated to fit context window -->")
|
||||
break
|
||||
|
||||
parts.append(doc_xml)
|
||||
total_chars += len(doc_xml)
|
||||
|
||||
return "\n".join(parts).strip() or "No web search results found."
|
||||
|
||||
|
||||
async def _search_live_connector(
|
||||
connector: str,
|
||||
query: str,
|
||||
search_space_id: int,
|
||||
top_k: int,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Dispatch a single live-search connector (Tavily / Linkup / Baidu)."""
|
||||
perf = get_perf_logger()
|
||||
spec = _LIVE_CONNECTOR_SPECS.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
|
||||
method_name, _includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as session:
|
||||
svc = ConnectorService(session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[web_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[web_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
|
||||
def create_web_search_tool(
|
||||
search_space_id: int | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Factory for the ``web_search`` tool.
|
||||
|
||||
Dispatches in parallel to the platform SearXNG instance and any
|
||||
user-configured live-search connectors (Tavily, Linkup, Baidu).
|
||||
"""
|
||||
active_live_connectors: list[str] = []
|
||||
if available_connectors:
|
||||
active_live_connectors = [
|
||||
c for c in available_connectors if c in _LIVE_SEARCH_CONNECTORS
|
||||
]
|
||||
|
||||
engine_names = ["SearXNG (platform default)"]
|
||||
engine_names.extend(_CONNECTOR_LABELS.get(c, c) for c in active_live_connectors)
|
||||
engines_summary = ", ".join(engine_names)
|
||||
|
||||
description = (
|
||||
"Search the web for real-time information. "
|
||||
"Use this for current events, news, prices, weather, public facts, or any "
|
||||
"question that requires up-to-date information from the internet.\n\n"
|
||||
f"Active search engines: {engines_summary}.\n"
|
||||
"All configured engines are queried in parallel and results are merged."
|
||||
)
|
||||
|
||||
_search_space_id = search_space_id
|
||||
_active_live = active_live_connectors
|
||||
|
||||
async def _web_search_impl(query: str, top_k: int = 10) -> str:
|
||||
from app.services import web_search_service
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
clamped_top_k = min(max(1, top_k), 50)
|
||||
|
||||
semaphore = asyncio.Semaphore(4)
|
||||
tasks: list[asyncio.Task[list[dict[str, Any]]]] = []
|
||||
|
||||
if web_search_service.is_available():
|
||||
|
||||
async def _searxng() -> list[dict[str, Any]]:
|
||||
async with semaphore:
|
||||
_result_obj, docs = await web_search_service.search(
|
||||
query=query,
|
||||
top_k=clamped_top_k,
|
||||
)
|
||||
return docs
|
||||
|
||||
tasks.append(asyncio.ensure_future(_searxng()))
|
||||
|
||||
if _search_space_id is not None:
|
||||
for connector in _active_live:
|
||||
tasks.append(
|
||||
asyncio.ensure_future(
|
||||
_search_live_connector(
|
||||
connector=connector,
|
||||
query=query,
|
||||
search_space_id=_search_space_id,
|
||||
top_k=clamped_top_k,
|
||||
semaphore=semaphore,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
return "Web search is not available — no search engines are configured."
|
||||
|
||||
results_lists = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
all_documents: list[dict[str, Any]] = []
|
||||
for result in results_lists:
|
||||
if isinstance(result, BaseException):
|
||||
perf.warning("[web_search] a search engine failed: %s", result)
|
||||
continue
|
||||
all_documents.extend(result)
|
||||
|
||||
seen_urls: set[str] = set()
|
||||
deduplicated: list[dict[str, Any]] = []
|
||||
for doc in all_documents:
|
||||
url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
|
||||
if url and url in seen_urls:
|
||||
continue
|
||||
if url:
|
||||
seen_urls.add(url)
|
||||
deduplicated.append(doc)
|
||||
|
||||
formatted = _format_web_results(deduplicated)
|
||||
|
||||
perf.info(
|
||||
"[web_search] query=%r engines=%d results=%d deduped=%d chars=%d in %.3fs",
|
||||
query[:60],
|
||||
len(tasks),
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(formatted),
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
return formatted
|
||||
|
||||
return StructuredTool(
|
||||
name="web_search",
|
||||
description=description,
|
||||
coroutine=_web_search_impl,
|
||||
args_schema=WebSearchInput,
|
||||
)
|
||||
10
surfsense_backend/app/agents/video_presentation/__init__.py
Normal file
10
surfsense_backend/app/agents/video_presentation/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Video Presentation LangGraph Agent.
|
||||
|
||||
This module defines a graph for generating video presentations
|
||||
from source content, similar to the podcaster agent but producing
|
||||
slide-based video presentations with TTS narration.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
__all__ = ["graph"]
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""Define the configurable parameters for the video presentation agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the video presentation agent."""
|
||||
|
||||
video_title: str
|
||||
search_space_id: int
|
||||
user_prompt: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
39
surfsense_backend/app/agents/video_presentation/graph.py
Normal file
39
surfsense_backend/app/agents/video_presentation/graph.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import (
|
||||
assign_slide_themes,
|
||||
create_presentation_slides,
|
||||
create_slide_audio,
|
||||
generate_slide_scene_codes,
|
||||
)
|
||||
from .state import State
|
||||
|
||||
|
||||
def build_graph():
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
workflow.add_node("create_presentation_slides", create_presentation_slides)
|
||||
workflow.add_node("create_slide_audio", create_slide_audio)
|
||||
workflow.add_node("assign_slide_themes", assign_slide_themes)
|
||||
workflow.add_node("generate_slide_scene_codes", generate_slide_scene_codes)
|
||||
|
||||
# Fan-out: after slides are parsed, run audio generation and theme
|
||||
# assignment in parallel (themes only need slide metadata, not audio).
|
||||
workflow.add_edge("__start__", "create_presentation_slides")
|
||||
workflow.add_edge("create_presentation_slides", "create_slide_audio")
|
||||
workflow.add_edge("create_presentation_slides", "assign_slide_themes")
|
||||
|
||||
# Fan-in: scene code generation waits for both audio and themes.
|
||||
workflow.add_edge("create_slide_audio", "generate_slide_scene_codes")
|
||||
workflow.add_edge("assign_slide_themes", "generate_slide_scene_codes")
|
||||
|
||||
workflow.add_edge("generate_slide_scene_codes", "__end__")
|
||||
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Video Presentation"
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
graph = build_graph()
|
||||
580
surfsense_backend/app/agents/video_presentation/nodes.py
Normal file
580
surfsense_backend/app/agents/video_presentation/nodes.py
Normal file
|
|
@ -0,0 +1,580 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from litellm import aspeech
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import (
|
||||
DEFAULT_DURATION_IN_FRAMES,
|
||||
FPS,
|
||||
REFINE_SCENE_SYSTEM_PROMPT,
|
||||
REMOTION_SCENE_SYSTEM_PROMPT,
|
||||
THEME_PRESETS,
|
||||
build_scene_generation_user_prompt,
|
||||
build_theme_assignment_user_prompt,
|
||||
get_slide_generation_prompt,
|
||||
get_theme_assignment_system_prompt,
|
||||
pick_theme_and_mode_fallback,
|
||||
)
|
||||
from .state import (
|
||||
PresentationSlides,
|
||||
SlideAudioResult,
|
||||
SlideContent,
|
||||
SlideSceneCode,
|
||||
State,
|
||||
)
|
||||
from .utils import get_voice_for_provider
|
||||
|
||||
MAX_REFINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def create_presentation_slides(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Parse source content into structured presentation slides using LLM."""
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
user_prompt = configuration.user_prompt
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No LLM configured for search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
prompt = get_slide_generation_prompt(user_prompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(
|
||||
content=f"<source_content>{state.source_content}</source_content>"
|
||||
),
|
||||
]
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
try:
|
||||
presentation = PresentationSlides.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
try:
|
||||
content = llm_response.content
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
parsed_data = json.loads(json_str)
|
||||
presentation = PresentationSlides.model_validate(parsed_data)
|
||||
print("Successfully parsed presentation slides using fallback approach")
|
||||
else:
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
raise
|
||||
|
||||
return {"slides": presentation.slides}
|
||||
|
||||
|
||||
async def create_slide_audio(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Generate TTS audio for each slide.
|
||||
|
||||
Each slide's speaker_transcripts are generated as individual TTS chunks,
|
||||
then concatenated with ffmpeg (matching the POC in RemotionTets/api/tts).
|
||||
"""
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
temp_dir = Path("temp_audio")
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
output_dir = Path("video_presentation_audio")
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
slides = state.slides or []
|
||||
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id=0)
|
||||
ext = "wav" if app_config.TTS_SERVICE == "local/kokoro" else "mp3"
|
||||
|
||||
async def _generate_tts_chunk(text: str, chunk_path: str) -> str:
|
||||
"""Generate a single TTS chunk and write it to *chunk_path*."""
|
||||
if app_config.TTS_SERVICE == "local/kokoro":
|
||||
kokoro_service = await get_kokoro_tts_service(lang_code="a")
|
||||
await kokoro_service.generate_speech(
|
||||
text=text,
|
||||
voice=voice,
|
||||
speed=1.0,
|
||||
output_path=chunk_path,
|
||||
)
|
||||
else:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.TTS_SERVICE,
|
||||
"api_key": app_config.TTS_SERVICE_API_KEY,
|
||||
"voice": voice,
|
||||
"input": text,
|
||||
"max_retries": 2,
|
||||
"timeout": 600,
|
||||
}
|
||||
if app_config.TTS_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.TTS_SERVICE_API_BASE
|
||||
|
||||
response = await aspeech(**kwargs)
|
||||
with open(chunk_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return chunk_path
|
||||
|
||||
async def _concat_with_ffmpeg(chunk_paths: list[str], output_file: str) -> None:
|
||||
"""Concatenate multiple audio chunks into one file using async ffmpeg."""
|
||||
ffmpeg = FFmpeg().option("y")
|
||||
for chunk in chunk_paths:
|
||||
ffmpeg = ffmpeg.input(chunk)
|
||||
|
||||
filter_parts = [f"[{i}:0]" for i in range(len(chunk_paths))]
|
||||
filter_str = (
|
||||
"".join(filter_parts) + f"concat=n={len(chunk_paths)}:v=0:a=1[outa]"
|
||||
)
|
||||
ffmpeg = ffmpeg.option("filter_complex", filter_str)
|
||||
ffmpeg = ffmpeg.output(output_file, map="[outa]")
|
||||
await ffmpeg.execute()
|
||||
|
||||
async def generate_audio_for_slide(slide: SlideContent) -> SlideAudioResult:
|
||||
has_transcripts = (
|
||||
slide.speaker_transcripts and len(slide.speaker_transcripts) > 0
|
||||
)
|
||||
|
||||
if not has_transcripts:
|
||||
print(
|
||||
f"Slide {slide.slide_number}: no speaker_transcripts, "
|
||||
f"using default duration ({DEFAULT_DURATION_IN_FRAMES} frames)"
|
||||
)
|
||||
return SlideAudioResult(
|
||||
slide_number=slide.slide_number,
|
||||
audio_file="",
|
||||
duration_seconds=DEFAULT_DURATION_IN_FRAMES / FPS,
|
||||
duration_in_frames=DEFAULT_DURATION_IN_FRAMES,
|
||||
)
|
||||
|
||||
output_file = str(output_dir / f"{session_id}_slide_{slide.slide_number}.{ext}")
|
||||
|
||||
chunk_paths: list[str] = []
|
||||
try:
|
||||
chunk_paths = [
|
||||
str(
|
||||
temp_dir
|
||||
/ f"{session_id}_slide_{slide.slide_number}_chunk_{i}.{ext}"
|
||||
)
|
||||
for i in range(len(slide.speaker_transcripts))
|
||||
]
|
||||
|
||||
for i, text in enumerate(slide.speaker_transcripts):
|
||||
print(
|
||||
f" Slide {slide.slide_number} chunk {i + 1}/"
|
||||
f"{len(slide.speaker_transcripts)}: "
|
||||
f'"{text[:60]}..."'
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_generate_tts_chunk(text, path)
|
||||
for text, path in zip(
|
||||
slide.speaker_transcripts, chunk_paths, strict=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if len(chunk_paths) == 1:
|
||||
shutil.move(chunk_paths[0], output_file)
|
||||
else:
|
||||
print(
|
||||
f" Concatenating {len(chunk_paths)} chunks for slide "
|
||||
f"{slide.slide_number} with ffmpeg"
|
||||
)
|
||||
await _concat_with_ffmpeg(chunk_paths, output_file)
|
||||
|
||||
duration_seconds = await _get_audio_duration(output_file)
|
||||
duration_in_frames = math.ceil(duration_seconds * FPS)
|
||||
|
||||
return SlideAudioResult(
|
||||
slide_number=slide.slide_number,
|
||||
audio_file=output_file,
|
||||
duration_seconds=duration_seconds,
|
||||
duration_in_frames=max(duration_in_frames, DEFAULT_DURATION_IN_FRAMES),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating audio for slide {slide.slide_number}: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
for p in chunk_paths:
|
||||
with contextlib.suppress(OSError):
|
||||
os.remove(p)
|
||||
|
||||
tasks = [generate_audio_for_slide(slide) for slide in slides]
|
||||
audio_results = await asyncio.gather(*tasks)
|
||||
|
||||
audio_results_sorted = sorted(audio_results, key=lambda r: r.slide_number)
|
||||
|
||||
print(
|
||||
f"Generated audio for {len(audio_results_sorted)} slides "
|
||||
f"(total duration: {sum(r.duration_seconds for r in audio_results_sorted):.1f}s)"
|
||||
)
|
||||
|
||||
return {"slide_audio_results": audio_results_sorted}
|
||||
|
||||
|
||||
async def _get_audio_duration(file_path: str) -> float:
|
||||
"""Get audio duration in seconds using ffprobe (via python-ffmpeg).
|
||||
|
||||
Falls back to file-size estimation if ffprobe fails.
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
file_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||
if proc.returncode == 0 and stdout.strip():
|
||||
return float(stdout.strip())
|
||||
except Exception as e:
|
||||
print(f"ffprobe failed for {file_path}: {e!s}, using file-size estimation")
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_path.endswith(".wav"):
|
||||
return file_size / (16000 * 2)
|
||||
else:
|
||||
return file_size / 16000
|
||||
except Exception:
|
||||
return DEFAULT_DURATION_IN_FRAMES / FPS
|
||||
|
||||
|
||||
async def _assign_themes_with_llm(
|
||||
llm, slides: list[SlideContent]
|
||||
) -> dict[int, tuple[str, str]]:
|
||||
"""Ask the LLM to assign a theme+mode to each slide in one call.
|
||||
|
||||
Returns a dict mapping slide_number → (theme, mode).
|
||||
Falls back to round-robin if the LLM response can't be parsed.
|
||||
"""
|
||||
total = len(slides)
|
||||
slide_summaries = [
|
||||
{
|
||||
"slide_number": s.slide_number,
|
||||
"title": s.title,
|
||||
"subtitle": s.subtitle or "",
|
||||
"background_explanation": s.background_explanation or "",
|
||||
}
|
||||
for s in slides
|
||||
]
|
||||
|
||||
system = get_theme_assignment_system_prompt()
|
||||
user = build_theme_assignment_user_prompt(slide_summaries)
|
||||
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
SystemMessage(content=system),
|
||||
HumanMessage(content=user),
|
||||
]
|
||||
)
|
||||
|
||||
text = response.content.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
text = "\n".join(
|
||||
line for line in lines if not line.strip().startswith("```")
|
||||
).strip()
|
||||
|
||||
assignments = json.loads(text)
|
||||
valid_themes = set(THEME_PRESETS)
|
||||
result: dict[int, tuple[str, str]] = {}
|
||||
for entry in assignments:
|
||||
sn = entry.get("slide_number")
|
||||
theme = entry.get("theme", "").upper()
|
||||
mode = entry.get("mode", "dark").lower()
|
||||
if sn and theme in valid_themes and mode in ("dark", "light"):
|
||||
result[sn] = (theme, mode)
|
||||
|
||||
if len(result) == total:
|
||||
print(
|
||||
"LLM theme assignment: "
|
||||
+ ", ".join(f"S{sn}={t}/{m}" for sn, (t, m) in sorted(result.items()))
|
||||
)
|
||||
return result
|
||||
|
||||
print(
|
||||
f"LLM returned {len(result)}/{total} valid assignments, "
|
||||
"filling gaps with fallback"
|
||||
)
|
||||
for s in slides:
|
||||
if s.slide_number not in result:
|
||||
result[s.slide_number] = pick_theme_and_mode_fallback(
|
||||
s.slide_number - 1, total
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM theme assignment failed ({e!s}), using fallback")
|
||||
return {
|
||||
s.slide_number: pick_theme_and_mode_fallback(s.slide_number - 1, total)
|
||||
for s in slides
|
||||
}
|
||||
|
||||
|
||||
async def assign_slide_themes(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Assign a theme preset + dark/light mode to every slide via a single LLM call.
|
||||
|
||||
Runs in parallel with audio generation since it only needs slide metadata.
|
||||
"""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
raise RuntimeError(f"No LLM configured for search space {search_space_id}")
|
||||
|
||||
slides = state.slides or []
|
||||
assignments = await _assign_themes_with_llm(llm, slides)
|
||||
return {"slide_theme_assignments": assignments}
|
||||
|
||||
|
||||
async def generate_slide_scene_codes(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate Remotion component code for each slide using LLM.
|
||||
|
||||
Reads pre-assigned themes from state (produced by the parallel
|
||||
assign_slide_themes node) and generates scene code concurrently.
|
||||
"""
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
raise RuntimeError(f"No LLM configured for search space {search_space_id}")
|
||||
|
||||
slides = state.slides or []
|
||||
audio_results = state.slide_audio_results or []
|
||||
|
||||
audio_map: dict[int, SlideAudioResult] = {r.slide_number: r for r in audio_results}
|
||||
total_slides = len(slides)
|
||||
|
||||
theme_assignments = state.slide_theme_assignments or {}
|
||||
|
||||
async def _generate_scene_for_slide(slide: SlideContent) -> SlideSceneCode:
|
||||
audio = audio_map.get(slide.slide_number)
|
||||
duration = audio.duration_in_frames if audio else DEFAULT_DURATION_IN_FRAMES
|
||||
|
||||
theme, mode = theme_assignments.get(
|
||||
slide.slide_number,
|
||||
pick_theme_and_mode_fallback(slide.slide_number - 1, total_slides),
|
||||
)
|
||||
|
||||
user_prompt = build_scene_generation_user_prompt(
|
||||
slide_number=slide.slide_number,
|
||||
total_slides=total_slides,
|
||||
title=slide.title,
|
||||
subtitle=slide.subtitle,
|
||||
content_in_markdown=slide.content_in_markdown,
|
||||
background_explanation=slide.background_explanation,
|
||||
duration_in_frames=duration,
|
||||
theme=theme,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=REMOTION_SCENE_SYSTEM_PROMPT),
|
||||
HumanMessage(content=user_prompt),
|
||||
]
|
||||
|
||||
print(
|
||||
f"Generating scene code for slide {slide.slide_number}/{total_slides}: "
|
||||
f'"{slide.title}" ({duration} frames)'
|
||||
)
|
||||
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
code, scene_title = _extract_code_and_title(llm_response.content)
|
||||
|
||||
code = await _refine_if_needed(llm, code, slide.slide_number)
|
||||
|
||||
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
|
||||
|
||||
return SlideSceneCode(
|
||||
slide_number=slide.slide_number,
|
||||
code=code,
|
||||
title=scene_title or slide.title,
|
||||
)
|
||||
|
||||
scene_codes = list(
|
||||
await asyncio.gather(*[_generate_scene_for_slide(s) for s in slides])
|
||||
)
|
||||
|
||||
return {"slide_scene_codes": scene_codes}
|
||||
|
||||
|
||||
def _extract_code_and_title(content: str) -> tuple[str, str | None]:
|
||||
"""Extract code and optional title from LLM response.
|
||||
|
||||
The LLM may return a JSON object like the POC's structured output:
|
||||
{ "code": "...", "title": "..." }
|
||||
Or it may return raw code (with optional markdown fences).
|
||||
|
||||
Returns (code, title) where title may be None.
|
||||
"""
|
||||
text = content.strip()
|
||||
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, dict) and "code" in parsed:
|
||||
return parsed["code"], parsed.get("title")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
json_start = text.find("{")
|
||||
json_end = text.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
try:
|
||||
parsed = json.loads(text[json_start:json_end])
|
||||
if isinstance(parsed, dict) and "code" in parsed:
|
||||
return parsed["code"], parsed.get("title")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
code = text
|
||||
if code.startswith("```"):
|
||||
lines = code.split("\n")
|
||||
start = 1
|
||||
end = len(lines)
|
||||
for i in range(len(lines) - 1, 0, -1):
|
||||
if lines[i].strip().startswith("```"):
|
||||
end = i
|
||||
break
|
||||
code = "\n".join(lines[start:end]).strip()
|
||||
|
||||
return code, None
|
||||
|
||||
|
||||
async def _refine_if_needed(llm, code: str, slide_number: int) -> str:
|
||||
"""Attempt basic syntax validation and auto-repair via LLM if needed.
|
||||
|
||||
Raises RuntimeError if the code is still invalid after MAX_REFINE_ATTEMPTS,
|
||||
matching the POC's behavior where a failed slide aborts the pipeline.
|
||||
"""
|
||||
error = _basic_syntax_check(code)
|
||||
if error is None:
|
||||
return code
|
||||
|
||||
for attempt in range(1, MAX_REFINE_ATTEMPTS + 1):
|
||||
print(
|
||||
f"Slide {slide_number}: syntax issue (attempt {attempt}/{MAX_REFINE_ATTEMPTS}): {error}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=REFINE_SCENE_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"Here is the broken Remotion component code:\n\n{code}\n\n"
|
||||
f"Compilation error:\n{error}\n\nFix the code."
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
code, _ = _extract_code_and_title(response.content)
|
||||
|
||||
error = _basic_syntax_check(code)
|
||||
if error is None:
|
||||
print(f"Slide {slide_number}: fixed on attempt {attempt}")
|
||||
return code
|
||||
|
||||
raise RuntimeError(
|
||||
f"Slide {slide_number} failed to compile after {MAX_REFINE_ATTEMPTS} "
|
||||
f"refine attempts. Last error: {error}"
|
||||
)
|
||||
|
||||
|
||||
def _basic_syntax_check(code: str) -> str | None:
|
||||
"""Run a lightweight syntax check on the generated code.
|
||||
|
||||
Full Babel-based compilation happens on the frontend. This backend check
|
||||
catches the most common LLM code-generation mistakes so the refine loop
|
||||
can fix them before persisting.
|
||||
|
||||
Returns an error description or None if the code looks valid.
|
||||
"""
|
||||
if not code or not code.strip():
|
||||
return "Empty code"
|
||||
|
||||
if "export" not in code and "MyComposition" not in code:
|
||||
return "Missing exported component (expected 'export const MyComposition')"
|
||||
|
||||
brace_count = 0
|
||||
paren_count = 0
|
||||
bracket_count = 0
|
||||
for ch in code:
|
||||
if ch == "{":
|
||||
brace_count += 1
|
||||
elif ch == "}":
|
||||
brace_count -= 1
|
||||
elif ch == "(":
|
||||
paren_count += 1
|
||||
elif ch == ")":
|
||||
paren_count -= 1
|
||||
elif ch == "[":
|
||||
bracket_count += 1
|
||||
elif ch == "]":
|
||||
bracket_count -= 1
|
||||
|
||||
if brace_count < 0:
|
||||
return "Unmatched closing brace '}'"
|
||||
if paren_count < 0:
|
||||
return "Unmatched closing parenthesis ')'"
|
||||
if bracket_count < 0:
|
||||
return "Unmatched closing bracket ']'"
|
||||
|
||||
if brace_count != 0:
|
||||
return f"Unbalanced braces: {brace_count} unclosed"
|
||||
if paren_count != 0:
|
||||
return f"Unbalanced parentheses: {paren_count} unclosed"
|
||||
if bracket_count != 0:
|
||||
return f"Unbalanced brackets: {bracket_count} unclosed"
|
||||
|
||||
if "useCurrentFrame" not in code:
|
||||
return "Missing useCurrentFrame() — required for Remotion animations"
|
||||
|
||||
if "AbsoluteFill" not in code:
|
||||
return "Missing AbsoluteFill — required as the root layout component"
|
||||
|
||||
return None
|
||||
509
surfsense_backend/app/agents/video_presentation/prompts.py
Normal file
509
surfsense_backend/app/agents/video_presentation/prompts.py
Normal file
|
|
@ -0,0 +1,509 @@
|
|||
import datetime
|
||||
|
||||
# TODO: move these to config file
|
||||
MAX_SLIDES = 5
|
||||
FPS = 30
|
||||
DEFAULT_DURATION_IN_FRAMES = 300
|
||||
|
||||
THEME_PRESETS = [
|
||||
"TERRA",
|
||||
"OCEAN",
|
||||
"SUNSET",
|
||||
"EMERALD",
|
||||
"ECLIPSE",
|
||||
"ROSE",
|
||||
"FROST",
|
||||
"NEBULA",
|
||||
"AURORA",
|
||||
"CORAL",
|
||||
"MIDNIGHT",
|
||||
"AMBER",
|
||||
"LAVENDER",
|
||||
"STEEL",
|
||||
"CITRUS",
|
||||
"CHERRY",
|
||||
]
|
||||
|
||||
THEME_DESCRIPTIONS: dict[str, str] = {
|
||||
"TERRA": "Warm earthy tones — terracotta, olive. Heritage, tradition, organic warmth.",
|
||||
"OCEAN": "Cool oceanic depth — teal, coral accents. Calm, marine, fluid elegance.",
|
||||
"SUNSET": "Vibrant warm energy — orange, purple. Passion, creativity, bold expression.",
|
||||
"EMERALD": "Fresh natural life — green, mint. Growth, health, sustainability.",
|
||||
"ECLIPSE": "Dramatic luxury — black, gold. Premium, power, prestige.",
|
||||
"ROSE": "Soft elegance — dusty pink, mauve. Beauty, care, refined femininity.",
|
||||
"FROST": "Crisp clarity — ice blue, silver. Tech, data, precision analytics.",
|
||||
"NEBULA": "Cosmic mystery — magenta, deep purple. AI, innovation, cutting-edge future.",
|
||||
"AURORA": "Ethereal northern lights — green-teal, violet. Mystical, transformative, wonder.",
|
||||
"CORAL": "Tropical warmth — coral, turquoise. Inviting, lively, community.",
|
||||
"MIDNIGHT": "Deep sophistication — navy, silver. Contemplative, trust, authority.",
|
||||
"AMBER": "Rich honey warmth — amber, brown. Comfort, wisdom, organic richness.",
|
||||
"LAVENDER": "Gentle dreaminess — purple, lilac. Calm, imaginative, serene.",
|
||||
"STEEL": "Industrial strength — gray, steel blue. Modern professional, reliability.",
|
||||
"CITRUS": "Bright optimism — yellow, lime. Energy, joy, fresh starts.",
|
||||
"CHERRY": "Bold impact — deep red, dark. Power, urgency, passionate conviction.",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-based theme assignment (replaces keyword-based pick_theme_and_mode)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
THEME_ASSIGNMENT_SYSTEM_PROMPT = """You are a visual design director assigning color themes to presentation slides.
|
||||
Given a list of slides, assign each slide a theme preset and color mode (dark or light).
|
||||
|
||||
Available themes (name — description):
|
||||
{theme_list}
|
||||
|
||||
Rules:
|
||||
1. Pick the theme that best matches each slide's mood, content, and visual direction.
|
||||
2. Maximize visual variety — avoid repeating the same theme on consecutive slides.
|
||||
3. Mix dark and light modes across the presentation for contrast and rhythm.
|
||||
4. Opening slides often benefit from a bold dark theme; closing/summary slides can go either way.
|
||||
5. The "background_explanation" field is the primary signal — it describes the intended mood and color direction.
|
||||
|
||||
Return ONLY a JSON array (no markdown fences, no explanation):
|
||||
[
|
||||
{{"slide_number": 1, "theme": "THEME_NAME", "mode": "dark"}},
|
||||
{{"slide_number": 2, "theme": "THEME_NAME", "mode": "light"}}
|
||||
]
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_theme_assignment_user_prompt(
|
||||
slides: list[dict[str, str]],
|
||||
) -> str:
|
||||
"""Build the user prompt for LLM theme assignment.
|
||||
|
||||
*slides* is a list of dicts with keys: slide_number, title, subtitle,
|
||||
background_explanation (mood).
|
||||
"""
|
||||
lines = ["Assign a theme and mode to each of these slides:", ""]
|
||||
for s in slides:
|
||||
lines.append(
|
||||
f'Slide {s["slide_number"]}: "{s["title"]}" '
|
||||
f'(subtitle: "{s.get("subtitle", "")}") — '
|
||||
f'Mood: "{s.get("background_explanation", "neutral")}"'
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_theme_assignment_system_prompt() -> str:
|
||||
"""Return the theme assignment system prompt with the full theme list injected."""
|
||||
theme_list = "\n".join(
|
||||
f"- {name}: {desc}" for name, desc in THEME_DESCRIPTIONS.items()
|
||||
)
|
||||
return THEME_ASSIGNMENT_SYSTEM_PROMPT.format(theme_list=theme_list)
|
||||
|
||||
|
||||
def pick_theme_and_mode_fallback(
|
||||
slide_index: int, total_slides: int
|
||||
) -> tuple[str, str]:
|
||||
"""Simple round-robin fallback when LLM theme assignment fails."""
|
||||
theme = THEME_PRESETS[slide_index % len(THEME_PRESETS)]
|
||||
mode = "dark" if slide_index % 2 == 0 else "light"
|
||||
if total_slides == 1:
|
||||
mode = "dark"
|
||||
return theme, mode
|
||||
|
||||
|
||||
def get_slide_generation_prompt(user_prompt: str | None = None) -> str:
|
||||
return f"""
|
||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
<video_presentation_system>
|
||||
You are a content-to-slides converter. You receive raw source content (articles, notes, transcripts,
|
||||
product descriptions, chat conversations, etc.) and break it into a sequence of presentation slides
|
||||
for a video presentation with voiceover narration.
|
||||
|
||||
{
|
||||
f'''
|
||||
You **MUST** strictly adhere to the following user instruction while generating the slides:
|
||||
<user_instruction>
|
||||
{user_prompt}
|
||||
</user_instruction>
|
||||
'''
|
||||
if user_prompt
|
||||
else ""
|
||||
}
|
||||
|
||||
<input>
|
||||
- '<source_content>': A block of text containing the information to be presented. This could be
|
||||
research findings, an article summary, a detailed outline, user chat history, or any relevant
|
||||
raw information. The content serves as the factual basis for the video presentation.
|
||||
</input>
|
||||
|
||||
<output_format>
|
||||
A JSON object containing the presentation slides:
|
||||
{{
|
||||
"slides": [
|
||||
{{
|
||||
"slide_number": 1,
|
||||
"title": "Concise slide title",
|
||||
"subtitle": "One-line subtitle or tagline",
|
||||
"content_in_markdown": "## Heading\\n- Bullet point 1\\n- **Bold text**\\n- Bullet point 3",
|
||||
"speaker_transcripts": [
|
||||
"First narration sentence for this slide.",
|
||||
"Second narration sentence expanding on the point.",
|
||||
"Third sentence wrapping up this slide."
|
||||
],
|
||||
"background_explanation": "Emotional mood and color direction for this slide"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</output_format>
|
||||
|
||||
<guidelines>
|
||||
=== SLIDE COUNT ===
|
||||
|
||||
Dynamically decide the number of slides between 1 and {MAX_SLIDES} (inclusive).
|
||||
Base your decision entirely on the content's depth, richness, and how many distinct ideas it contains.
|
||||
Thin or simple content should produce fewer slides; dense or multi-faceted content may use more.
|
||||
Do NOT inflate or pad slides to reach {
|
||||
MAX_SLIDES
|
||||
} — only use what the content genuinely warrants.
|
||||
Do NOT treat {MAX_SLIDES} as a target; it is a hard ceiling, not a goal.
|
||||
|
||||
=== SLIDE STRUCTURE ===
|
||||
|
||||
- Each slide should cover ONE distinct key idea or section.
|
||||
- Keep slides focused: 2-5 bullet points of content per slide max.
|
||||
- The first slide should be a title/intro slide.
|
||||
- The last slide should be a summary or closing slide ONLY if there are 3+ slides.
|
||||
For 1-2 slides, skip the closing slide — just cover the content.
|
||||
- Do NOT create a separate closing slide if its content would just repeat earlier slides.
|
||||
|
||||
=== CONTENT FIELDS ===
|
||||
|
||||
- Write speaker_transcripts as if a human presenter is narrating — natural, conversational, 2-4 sentences per slide.
|
||||
These will be converted to TTS audio, so write in a way that sounds great when spoken aloud.
|
||||
- background_explanation should describe a visual style matching the slide's mood:
|
||||
- Describe the emotional feel: "warm and organic", "dramatic and urgent", "clean and optimistic",
|
||||
"technical and precise", "celebratory", "earthy and grounded", "cosmic and futuristic"
|
||||
- Mention color direction: warm tones, cool tones, earth tones, neon accents, gold/black, etc.
|
||||
- Vary the mood across slides — do NOT always say "dark blue gradient".
|
||||
- content_in_markdown should use proper markdown: ## headings, **bold**, - bullets, etc.
|
||||
|
||||
=== NARRATION QUALITY ===
|
||||
|
||||
- Speaker transcripts should explain the slide content in an engaging, presenter-like voice.
|
||||
- Keep narration concise: 2-4 sentences per slide (targeting ~10-15 seconds of audio per slide).
|
||||
- The narration should add context beyond what's on the slide — don't just read the bullets.
|
||||
- Use natural language: contractions, conversational tone, occasional enthusiasm.
|
||||
</guidelines>
|
||||
|
||||
<examples>
|
||||
Input: "Quantum computing uses quantum bits or qubits which can exist in multiple states simultaneously due to superposition."
|
||||
|
||||
Output:
|
||||
{{
|
||||
"slides": [
|
||||
{{
|
||||
"slide_number": 1,
|
||||
"title": "Quantum Computing",
|
||||
"subtitle": "Beyond Classical Bits",
|
||||
"content_in_markdown": "## The Quantum Leap\\n- Classical computers use **bits** (0 or 1)\\n- Quantum computers use **qubits**\\n- Qubits leverage **superposition**",
|
||||
"speaker_transcripts": [
|
||||
"Let's explore quantum computing, a technology that's fundamentally different from the computers we use every day.",
|
||||
"While traditional computers work with bits that are either zero or one, quantum computers use something called qubits.",
|
||||
"The magic of qubits is superposition — they can exist in multiple states at the same time."
|
||||
],
|
||||
"background_explanation": "Cosmic and futuristic with deep purple and magenta tones, evoking the mystery of quantum mechanics"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</examples>
|
||||
|
||||
Transform the source material into well-structured presentation slides with engaging narration.
|
||||
Ensure each slide has a clear visual mood and natural-sounding speaker transcripts.
|
||||
</video_presentation_system>
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Remotion scene code generation prompt
|
||||
# Ported from RemotionTets POC /api/generate system prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
REMOTION_SCENE_SYSTEM_PROMPT = """
|
||||
You are a Remotion component generator that creates cinematic, modern motion graphics.
|
||||
Generate a single self-contained React component that uses Remotion.
|
||||
|
||||
=== THEME PRESETS (pick ONE per slide — see user prompt for which to use) ===
|
||||
|
||||
Each slide MUST use a DIFFERENT preset. The user prompt will tell you which preset to use.
|
||||
Use ALL colors from that preset — background, surface, text, accent, glow. Do NOT mix presets.
|
||||
|
||||
TERRA (warm earth — terracotta + olive):
|
||||
dark: bg #1C1510 surface #261E16 border #3D3024 text #E8DDD0 muted #9A8A78 accent #C2623D secondary #7D8C52 glow rgba(194,98,61,0.12)
|
||||
light: bg #F7F0E8 surface #FFF8F0 border #DDD0BF text #2C1D0E muted #8A7A68 accent #B85430 secondary #6B7A42 glow rgba(184,84,48,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 30% 80%, rgba(194,98,61,0.18), transparent 60%), linear-gradient(180deg, #1C1510, #261E16)
|
||||
gradient-light: radial-gradient(ellipse at 70% 20%, rgba(107,122,66,0.12), transparent 55%), linear-gradient(180deg, #F7F0E8, #FFF8F0)
|
||||
|
||||
OCEAN (cool depth — teal + coral):
|
||||
dark: bg #0B1A1E surface #122428 border #1E3740 text #D5EAF0 muted #6A9AA8 accent #1DB6A8 secondary #E87461 glow rgba(29,182,168,0.12)
|
||||
light: bg #F0F8FA surface #FFFFFF border #C8E0E8 text #0E2830 muted #5A8A98 accent #0EA69A secondary #D05F4E glow rgba(14,166,154,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 80% 30%, rgba(29,182,168,0.20), transparent 55%), radial-gradient(circle at 20% 80%, rgba(232,116,97,0.10), transparent 50%), #0B1A1E
|
||||
gradient-light: radial-gradient(ellipse at 20% 40%, rgba(14,166,154,0.10), transparent 55%), linear-gradient(180deg, #F0F8FA, #FFFFFF)
|
||||
|
||||
SUNSET (warm energy — orange + purple):
|
||||
dark: bg #1E130F surface #2A1B14 border #42291C text #F0DDD0 muted #A08878 accent #E86A20 secondary #A855C0 glow rgba(232,106,32,0.12)
|
||||
light: bg #FFF5ED surface #FFFFFF border #EADAC8 text #2E1508 muted #907860 accent #D05A18 secondary #9045A8 glow rgba(208,90,24,0.08)
|
||||
gradient-dark: linear-gradient(135deg, rgba(232,106,32,0.15) 0%, transparent 40%), radial-gradient(circle at 80% 70%, rgba(168,85,192,0.15), transparent 50%), #1E130F
|
||||
gradient-light: linear-gradient(135deg, rgba(208,90,24,0.08) 0%, rgba(144,69,168,0.06) 100%), #FFF5ED
|
||||
|
||||
EMERALD (fresh life — green + mint):
|
||||
dark: bg #0B1E14 surface #12281A border #1E3C28 text #D0F0E0 muted #5EA880 accent #10B981 secondary #84CC16 glow rgba(16,185,129,0.12)
|
||||
light: bg #F0FAF5 surface #FFFFFF border #C0E8D0 text #0E2C18 muted #489068 accent #059669 secondary #65A30D glow rgba(5,150,105,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 50% 50%, rgba(16,185,129,0.18), transparent 60%), linear-gradient(180deg, #0B1E14, #12281A)
|
||||
gradient-light: radial-gradient(ellipse at 60% 30%, rgba(101,163,13,0.10), transparent 55%), linear-gradient(180deg, #F0FAF5, #FFFFFF)
|
||||
|
||||
ECLIPSE (dramatic — black + gold):
|
||||
dark: bg #100C05 surface #1A1508 border #2E2510 text #D4B96A muted #8A7840 accent #E8B830 secondary #C09020 glow rgba(232,184,48,0.14)
|
||||
light: bg #FAF6ED surface #FFFFFF border #E0D8C0 text #1A1408 muted #7A6818 accent #C09820 secondary #A08018 glow rgba(192,152,32,0.08)
|
||||
gradient-dark: radial-gradient(circle at 50% 40%, rgba(232,184,48,0.20), transparent 50%), radial-gradient(ellipse at 50% 90%, rgba(192,144,32,0.08), transparent 50%), #100C05
|
||||
gradient-light: radial-gradient(circle at 50% 40%, rgba(192,152,32,0.10), transparent 55%), linear-gradient(180deg, #FAF6ED, #FFFFFF)
|
||||
|
||||
ROSE (soft elegance — dusty pink + mauve):
|
||||
dark: bg #1E1018 surface #281820 border #3D2830 text #F0D8E0 muted #A08090 accent #E4508C secondary #B06498 glow rgba(228,80,140,0.12)
|
||||
light: bg #FDF2F5 surface #FFFFFF border #F0D0D8 text #2C1018 muted #906878 accent #D43D78 secondary #9A5080 glow rgba(212,61,120,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 70% 30%, rgba(228,80,140,0.18), transparent 55%), radial-gradient(circle at 20% 80%, rgba(176,100,152,0.10), transparent 50%), #1E1018
|
||||
gradient-light: radial-gradient(ellipse at 30% 60%, rgba(212,61,120,0.08), transparent 55%), linear-gradient(180deg, #FDF2F5, #FFFFFF)
|
||||
|
||||
FROST (crisp clarity — ice blue + silver):
|
||||
dark: bg #0A1520 surface #101D2A border #1A3040 text #D0E5F5 muted #6090B0 accent #5AB4E8 secondary #8BA8C0 glow rgba(90,180,232,0.12)
|
||||
light: bg #F0F6FC surface #FFFFFF border #C8D8E8 text #0C1820 muted #5080A0 accent #3A96D0 secondary #7090A8 glow rgba(58,150,208,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 40% 20%, rgba(90,180,232,0.16), transparent 55%), linear-gradient(180deg, #0A1520, #101D2A)
|
||||
gradient-light: radial-gradient(ellipse at 50% 50%, rgba(58,150,208,0.08), transparent 55%), linear-gradient(180deg, #F0F6FC, #FFFFFF)
|
||||
|
||||
NEBULA (cosmic — magenta + deep purple):
|
||||
dark: bg #150A1E surface #1E1028 border #351A48 text #E0D0F0 muted #8060A0 accent #C850E0 secondary #8030C0 glow rgba(200,80,224,0.14)
|
||||
light: bg #F8F0FF surface #FFFFFF border #E0C8F0 text #1A0A24 muted #7050A0 accent #A840C0 secondary #6820A0 glow rgba(168,64,192,0.08)
|
||||
gradient-dark: radial-gradient(circle at 60% 40%, rgba(200,80,224,0.18), transparent 50%), radial-gradient(ellipse at 30% 80%, rgba(128,48,192,0.12), transparent 50%), #150A1E
|
||||
gradient-light: radial-gradient(circle at 40% 30%, rgba(168,64,192,0.10), transparent 55%), linear-gradient(180deg, #F8F0FF, #FFFFFF)
|
||||
|
||||
AURORA (ethereal lights — green-teal + violet):
|
||||
dark: bg #0A1A1A surface #102020 border #1A3838 text #D0F0F0 muted #60A0A0 accent #30D0B0 secondary #8040D0 glow rgba(48,208,176,0.12)
|
||||
light: bg #F0FAF8 surface #FFFFFF border #C0E8E0 text #0A2020 muted #508080 accent #20B090 secondary #6830B0 glow rgba(32,176,144,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 30% 70%, rgba(48,208,176,0.18), transparent 55%), radial-gradient(circle at 70% 30%, rgba(128,64,208,0.12), transparent 50%), #0A1A1A
|
||||
gradient-light: radial-gradient(ellipse at 50% 40%, rgba(32,176,144,0.10), transparent 55%), linear-gradient(180deg, #F0FAF8, #FFFFFF)
|
||||
|
||||
CORAL (tropical warmth — coral + turquoise):
|
||||
dark: bg #1E0F0F surface #281818 border #402828 text #F0D8D8 muted #A07070 accent #F06050 secondary #30B8B0 glow rgba(240,96,80,0.12)
|
||||
light: bg #FFF5F3 surface #FFFFFF border #F0D0C8 text #2E1010 muted #906060 accent #E04838 secondary #20A098 glow rgba(224,72,56,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 60% 60%, rgba(240,96,80,0.18), transparent 55%), radial-gradient(circle at 30% 30%, rgba(48,184,176,0.10), transparent 50%), #1E0F0F
|
||||
gradient-light: radial-gradient(ellipse at 40% 50%, rgba(224,72,56,0.08), transparent 55%), linear-gradient(180deg, #FFF5F3, #FFFFFF)
|
||||
|
||||
MIDNIGHT (deep sophistication — navy + silver):
|
||||
dark: bg #080C18 surface #0E1420 border #1A2438 text #C8D8F0 muted #5070A0 accent #4080E0 secondary #A0B0D0 glow rgba(64,128,224,0.12)
|
||||
light: bg #F0F2F8 surface #FFFFFF border #C8D0E0 text #101828 muted #506080 accent #3060C0 secondary #8090B0 glow rgba(48,96,192,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 50% 30%, rgba(64,128,224,0.16), transparent 55%), linear-gradient(180deg, #080C18, #0E1420)
|
||||
gradient-light: radial-gradient(ellipse at 50% 50%, rgba(48,96,192,0.08), transparent 55%), linear-gradient(180deg, #F0F2F8, #FFFFFF)
|
||||
|
||||
AMBER (rich honey warmth — amber + brown):
|
||||
dark: bg #1A1208 surface #221A0E border #3A2C18 text #F0E0C0 muted #A09060 accent #E0A020 secondary #C08030 glow rgba(224,160,32,0.12)
|
||||
light: bg #FFF8E8 surface #FFFFFF border #E8D8B8 text #2A1C08 muted #907840 accent #C88810 secondary #A86820 glow rgba(200,136,16,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 40% 60%, rgba(224,160,32,0.18), transparent 55%), linear-gradient(180deg, #1A1208, #221A0E)
|
||||
gradient-light: radial-gradient(ellipse at 60% 40%, rgba(200,136,16,0.10), transparent 55%), linear-gradient(180deg, #FFF8E8, #FFFFFF)
|
||||
|
||||
LAVENDER (gentle dreaminess — purple + lilac):
|
||||
dark: bg #14101E surface #1C1628 border #302840 text #E0D8F0 muted #8070A0 accent #A060E0 secondary #C090D0 glow rgba(160,96,224,0.12)
|
||||
light: bg #F8F0FF surface #FFFFFF border #E0D0F0 text #1C1028 muted #706090 accent #8848C0 secondary #A878B8 glow rgba(136,72,192,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 60% 40%, rgba(160,96,224,0.18), transparent 55%), radial-gradient(circle at 30% 70%, rgba(192,144,208,0.10), transparent 50%), #14101E
|
||||
gradient-light: radial-gradient(ellipse at 40% 30%, rgba(136,72,192,0.10), transparent 55%), linear-gradient(180deg, #F8F0FF, #FFFFFF)
|
||||
|
||||
STEEL (industrial strength — gray + steel blue):
|
||||
dark: bg #101214 surface #181C20 border #282E38 text #D0D8E0 muted #708090 accent #5088B0 secondary #90A0B0 glow rgba(80,136,176,0.12)
|
||||
light: bg #F2F4F6 surface #FFFFFF border #D0D8E0 text #181C24 muted #607080 accent #3870A0 secondary #708898 glow rgba(56,112,160,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 50% 50%, rgba(80,136,176,0.14), transparent 55%), linear-gradient(180deg, #101214, #181C20)
|
||||
gradient-light: radial-gradient(ellipse at 50% 40%, rgba(56,112,160,0.08), transparent 55%), linear-gradient(180deg, #F2F4F6, #FFFFFF)
|
||||
|
||||
CITRUS (bright optimism — yellow + lime):
|
||||
dark: bg #181808 surface #202010 border #383818 text #F0F0C0 muted #A0A060 accent #E8D020 secondary #90D030 glow rgba(232,208,32,0.12)
|
||||
light: bg #FFFFF0 surface #FFFFFF border #E8E8C0 text #282808 muted #808040 accent #C8B010 secondary #70B020 glow rgba(200,176,16,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 40% 40%, rgba(232,208,32,0.18), transparent 55%), radial-gradient(circle at 70% 70%, rgba(144,208,48,0.10), transparent 50%), #181808
|
||||
gradient-light: radial-gradient(ellipse at 50% 30%, rgba(200,176,16,0.10), transparent 55%), linear-gradient(180deg, #FFFFF0, #FFFFFF)
|
||||
|
||||
CHERRY (bold impact — deep red + dark):
|
||||
dark: bg #1A0808 surface #241010 border #401818 text #F0D0D0 muted #A06060 accent #D02030 secondary #E05060 glow rgba(208,32,48,0.14)
|
||||
light: bg #FFF0F0 surface #FFFFFF border #F0C8C8 text #280808 muted #904848 accent #B01828 secondary #C83848 glow rgba(176,24,40,0.08)
|
||||
gradient-dark: radial-gradient(ellipse at 50% 40%, rgba(208,32,48,0.20), transparent 50%), linear-gradient(180deg, #1A0808, #241010)
|
||||
gradient-light: radial-gradient(ellipse at 50% 50%, rgba(176,24,40,0.10), transparent 55%), linear-gradient(180deg, #FFF0F0, #FFFFFF)
|
||||
|
||||
=== SHARED TOKENS (use with any theme above) ===
|
||||
|
||||
SPACING: xs 8px, sm 16px, md 24px, lg 32px, xl 48px, 2xl 64px, 3xl 96px, 4xl 128px
|
||||
TYPOGRAPHY: fontFamily "Inter, system-ui, -apple-system, sans-serif"
|
||||
caption 14px/1.4, body 18px/1.6, subhead 24px/1.4, title 40px/1.2 w600, headline 64px/1.1 w700, display 96px/1.0 w800
|
||||
letterSpacing: tight "-0.02em", normal "0", wide "0.05em"
|
||||
BORDER RADIUS: 12px (cards), 8px (buttons), 9999px (pills)
|
||||
|
||||
=== VISUAL VARIETY (CRITICAL) ===
|
||||
|
||||
The user prompt assigns each slide a specific theme preset AND mode (dark/light).
|
||||
You MUST use EXACTLY the assigned preset and mode. Additionally:
|
||||
|
||||
1. Use the preset's gradient as the AbsoluteFill background.
|
||||
2. Use the preset's accent/secondary colors for highlights, pill badges, and card accents.
|
||||
3. Use the preset's glow value for all boxShadow effects.
|
||||
4. LAYOUT VARIATION: Vary layout between slides:
|
||||
- One slide: bold centered headline + subtle stat
|
||||
- Another: two-column card layout
|
||||
- Another: single large number or quote as hero
|
||||
Do NOT use the same layout pattern for every slide.
|
||||
|
||||
=== LAYOUT RULES (CRITICAL — elements must NEVER overlap) ===
|
||||
|
||||
The canvas is 1920x1080. You MUST use a SINGLE-LAYER layout. NO stacking, NO multiple AbsoluteFill layers.
|
||||
|
||||
STRUCTURE — every component must follow this exact pattern:
|
||||
<AbsoluteFill style={{ backgroundColor: "...", display: "flex", flexDirection: "column", justifyContent: "center", alignItems: "center", padding: 80 }}>
|
||||
{/* ALL content goes here as direct children in normal flow */}
|
||||
</AbsoluteFill>
|
||||
|
||||
ABSOLUTE RULES:
|
||||
- Use exactly ONE AbsoluteFill as the root. Set its background color/gradient via its style prop.
|
||||
- NEVER nest AbsoluteFill inside AbsoluteFill.
|
||||
- NEVER use position "absolute" or position "fixed" on ANY element.
|
||||
- NEVER use multiple layers or z-index.
|
||||
- ALL elements must be in normal document flow inside the single root AbsoluteFill.
|
||||
|
||||
SPACING:
|
||||
- Root padding: 80px on all sides (safe area).
|
||||
- Use flexDirection "column" with gap for vertical stacking, flexDirection "row" with gap for horizontal.
|
||||
- Minimum gap between elements: 24px vertical, 32px horizontal.
|
||||
- Text hierarchy gaps: headline→subheading 16px, subheading→body 12px, body→button 32px.
|
||||
- Cards/panels: padding 32px-48px, borderRadius 12px.
|
||||
- NEVER use margin to space siblings — always use the parent's gap property.
|
||||
|
||||
=== DESIGN STYLE ===
|
||||
|
||||
- Premium aesthetic — use the exact colors from the assigned theme preset (do NOT invent your own)
|
||||
- Background: use the preset's gradient-dark or gradient-light value directly as the AbsoluteFill's background
|
||||
- Card/surface backgrounds: use the preset's surface color
|
||||
- Text colors: use the preset's text, muted values
|
||||
- Borders: use the preset's border color
|
||||
- Glows: use the preset's glow value for all boxShadow — do NOT substitute other colors
|
||||
- Generous whitespace — less is more, let elements breathe
|
||||
- NO decorative background shapes, blurs, or overlapping ornaments
|
||||
|
||||
=== REMOTION RULES ===
|
||||
|
||||
- Export the component as: export const MyComposition = () => { ... }
|
||||
- Use useCurrentFrame() and useVideoConfig() from "remotion"
|
||||
- Do NOT use Sequence
|
||||
- Do NOT manually calculate animation timings or frame offsets
|
||||
|
||||
=== ANIMATION (use the stagger() helper for ALL element animations) ===
|
||||
|
||||
A pre-built helper function called stagger() is available globally.
|
||||
It handles enter, hold, and exit phases automatically — you MUST use it.
|
||||
|
||||
Signature:
|
||||
stagger(frame, fps, index, total) → { opacity: number, transform: string }
|
||||
|
||||
Parameters:
|
||||
frame — from useCurrentFrame()
|
||||
fps — from useVideoConfig()
|
||||
index — 0-based index of this element in the entrance order
|
||||
total — total number of animated elements in the scene
|
||||
|
||||
It returns a style object with opacity and transform that you spread onto the element.
|
||||
Timing is handled for you: staggered spring entrances, ambient hold motion, and a graceful exit.
|
||||
|
||||
Usage pattern:
|
||||
const frame = useCurrentFrame();
|
||||
const { fps } = useVideoConfig();
|
||||
|
||||
<div style={stagger(frame, fps, 0, 4)}>Headline</div>
|
||||
<div style={stagger(frame, fps, 1, 4)}>Subtitle</div>
|
||||
<div style={stagger(frame, fps, 2, 4)}>Card</div>
|
||||
<div style={stagger(frame, fps, 3, 4)}>Footer</div>
|
||||
|
||||
Rules:
|
||||
- Count ALL animated elements in your scene and pass that count as the "total" parameter.
|
||||
- Assign each element a sequential index starting from 0.
|
||||
- You can merge stagger's return with additional styles:
|
||||
<div style={{ ...stagger(frame, fps, 0, 3), fontSize: 64, color: "#fafafa" }}>
|
||||
- For non-animated static elements (backgrounds, borders), just use normal styles without stagger.
|
||||
- You may still use spring() and interpolate() for EXTRA custom effects (e.g., a number counter,
|
||||
color shift, or typewriter effect), but stagger() must drive all entrance/exit animations.
|
||||
|
||||
=== AVAILABLE GLOBALS (injected at runtime, do NOT import anything else) ===
|
||||
|
||||
- React (available globally)
|
||||
- AbsoluteFill, useCurrentFrame, useVideoConfig, spring, interpolate, Easing from "remotion"
|
||||
- stagger(frame, fps, index, total) — animation helper described above
|
||||
|
||||
=== CODE RULES ===
|
||||
|
||||
- Output ONLY the raw code, no markdown fences, no explanations
|
||||
- Keep it fully self-contained, no external dependencies or images
|
||||
- Use inline styles only (no CSS imports, no className)
|
||||
- Target 1920x1080 resolution
|
||||
- Every container must use display "flex" with explicit gap values
|
||||
- NEVER use marginTop/marginBottom to space siblings — use the parent's gap instead
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_scene_generation_user_prompt(
|
||||
slide_number: int,
|
||||
total_slides: int,
|
||||
title: str,
|
||||
subtitle: str,
|
||||
content_in_markdown: str,
|
||||
background_explanation: str,
|
||||
duration_in_frames: int,
|
||||
theme: str,
|
||||
mode: str,
|
||||
) -> str:
|
||||
"""Build the user prompt for generating a single slide's Remotion scene code.
|
||||
|
||||
*theme* and *mode* are pre-assigned (by LLM or fallback) before this is called.
|
||||
"""
|
||||
return "\n".join(
|
||||
[
|
||||
"Create a cinematic, visually striking Remotion scene.",
|
||||
f"The video is {duration_in_frames} frames at {FPS}fps ({duration_in_frames / FPS:.1f}s total).",
|
||||
"",
|
||||
f"This is slide {slide_number} of {total_slides} in the video.",
|
||||
"",
|
||||
f"=== ASSIGNED THEME: {theme} / {mode.upper()} mode ===",
|
||||
f"You MUST use the {theme} preset in {mode} mode from the theme presets above.",
|
||||
f"Use its exact background gradient (gradient-{mode}), surface, text, accent, secondary, border, and glow colors.",
|
||||
"Do NOT substitute, invent, or default to blue/violet colors.",
|
||||
"",
|
||||
f'The scene should communicate this message: "{title} — {subtitle}"',
|
||||
"",
|
||||
"Key ideas to convey (use as creative inspiration, NOT literal text to dump on screen):",
|
||||
content_in_markdown,
|
||||
"",
|
||||
"Pick only the 1-2 most impactful phrases or numbers to display as text.",
|
||||
"",
|
||||
f"Mood & tone: {background_explanation}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
REFINE_SCENE_SYSTEM_PROMPT = """
|
||||
You are a code repair assistant. You will receive a Remotion React component that failed to compile,
|
||||
along with the exact error message from the Babel transpiler.
|
||||
|
||||
Your job is to fix the code so it compiles and runs correctly.
|
||||
|
||||
RULES:
|
||||
- Output ONLY the fixed raw code as a string — no markdown fences, no explanations.
|
||||
- Preserve the original intent, design, and animations as closely as possible.
|
||||
- The component must be exported as: export const MyComposition = () => { ... }
|
||||
- Only these globals are available at runtime (they are injected, not actually imported):
|
||||
React, AbsoluteFill, useCurrentFrame, useVideoConfig, spring, interpolate, Easing,
|
||||
stagger (a helper: stagger(frame, fps, index, total) → { opacity, transform })
|
||||
- Keep import statements at the top (they get stripped by the compiler) but do NOT import anything
|
||||
other than "react" and "remotion".
|
||||
- Use inline styles only (no CSS, no className).
|
||||
- Common fixes:
|
||||
- Mismatched braces/brackets in JSX style objects (e.g. }}, instead of }}>)
|
||||
- Missing closing tags
|
||||
- Trailing commas before > in JSX
|
||||
- Undefined variables or typos
|
||||
- Invalid JSX expressions
|
||||
- After fixing, mentally walk through every brace pair { } and JSX tag to verify they match.
|
||||
""".strip()
|
||||
73
surfsense_backend/app/agents/video_presentation/state.py
Normal file
73
surfsense_backend/app/agents/video_presentation/state.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Define the state structures for the video presentation agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class SlideContent(BaseModel):
|
||||
"""Represents a single parsed slide from content analysis."""
|
||||
|
||||
slide_number: int = Field(..., description="1-based slide number")
|
||||
title: str = Field(..., description="Concise slide title")
|
||||
subtitle: str = Field(..., description="One-line subtitle or tagline")
|
||||
content_in_markdown: str = Field(
|
||||
..., description="Slide body content formatted as markdown"
|
||||
)
|
||||
speaker_transcripts: list[str] = Field(
|
||||
...,
|
||||
description="2-4 short sentences a presenter would say while this slide is shown",
|
||||
)
|
||||
background_explanation: str = Field(
|
||||
...,
|
||||
description="Emotional mood and color direction for this slide",
|
||||
)
|
||||
|
||||
|
||||
class PresentationSlides(BaseModel):
|
||||
"""Represents the full set of parsed slides from the LLM."""
|
||||
|
||||
slides: list[SlideContent] = Field(
|
||||
..., description="Ordered array of presentation slides"
|
||||
)
|
||||
|
||||
|
||||
class SlideAudioResult(BaseModel):
|
||||
"""Audio generation result for a single slide."""
|
||||
|
||||
slide_number: int
|
||||
audio_file: str = Field(..., description="Path to the per-slide audio file")
|
||||
duration_seconds: float = Field(..., description="Audio duration in seconds")
|
||||
duration_in_frames: int = Field(
|
||||
..., description="Audio duration in frames (at 30fps)"
|
||||
)
|
||||
|
||||
|
||||
class SlideSceneCode(BaseModel):
|
||||
"""Generated Remotion component code for a single slide."""
|
||||
|
||||
slide_number: int
|
||||
code: str = Field(
|
||||
..., description="Raw Remotion React component source code for this slide"
|
||||
)
|
||||
title: str = Field(..., description="Short title for the composition")
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""State for the video presentation agent graph.
|
||||
|
||||
Pipeline: parse slides → (TTS audio ∥ theme assignment) → generate Remotion code
|
||||
The frontend receives the slides + code + audio and handles compilation/rendering.
|
||||
"""
|
||||
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
|
||||
slides: list[SlideContent] | None = None
|
||||
slide_audio_results: list[SlideAudioResult] | None = None
|
||||
slide_theme_assignments: dict[int, tuple[str, str]] | None = None
|
||||
slide_scene_codes: list[SlideSceneCode] | None = None
|
||||
30
surfsense_backend/app/agents/video_presentation/utils.py
Normal file
30
surfsense_backend/app/agents/video_presentation/utils.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
def get_voice_for_provider(provider: str, speaker_id: int = 0) -> dict | str:
|
||||
"""
|
||||
Get the appropriate voice configuration based on the TTS provider.
|
||||
|
||||
Currently single-speaker only (speaker_id=0). Multi-speaker support
|
||||
will be added in a future iteration.
|
||||
|
||||
Args:
|
||||
provider: The TTS provider (e.g., "openai/tts-1", "vertex_ai/test")
|
||||
speaker_id: The ID of the speaker (default 0, single speaker for now)
|
||||
|
||||
Returns:
|
||||
Voice configuration - string for OpenAI, dict for Vertex AI
|
||||
"""
|
||||
if provider == "local/kokoro":
|
||||
return "af_heart"
|
||||
|
||||
provider_type = (
|
||||
provider.split("/")[0].lower() if "/" in provider else provider.lower()
|
||||
)
|
||||
|
||||
voices = {
|
||||
"openai": "alloy",
|
||||
"vertex_ai": {
|
||||
"languageCode": "en-US",
|
||||
"name": "en-US-Studio-O",
|
||||
},
|
||||
"azure": "alloy",
|
||||
}
|
||||
return voices.get(provider_type, {})
|
||||
|
|
@ -341,7 +341,7 @@ if config.NEXT_FRONTEND_URL:
|
|||
allowed_origins.append(www_url)
|
||||
|
||||
allowed_origins.extend(
|
||||
[ # For local development and desktop app
|
||||
[ # For local development and desktop app
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ celery_app = Celery(
|
|||
include=[
|
||||
"app.tasks.celery_tasks.document_tasks",
|
||||
"app.tasks.celery_tasks.podcast_tasks",
|
||||
"app.tasks.celery_tasks.video_presentation_tasks",
|
||||
"app.tasks.celery_tasks.connector_tasks",
|
||||
"app.tasks.celery_tasks.schedule_checker_task",
|
||||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
|
|
|
|||
|
|
@ -224,6 +224,9 @@ class Config:
|
|||
os.getenv("CONNECTOR_INDEXING_LOCK_TTL_SECONDS", str(8 * 60 * 60))
|
||||
)
|
||||
|
||||
# Platform web search (SearXNG)
|
||||
SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST")
|
||||
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
# Backend URL to override the http to https in the OAuth redirect URI
|
||||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
|
|
|
|||
|
|
@ -1,719 +0,0 @@
|
|||
"""
|
||||
Composio Gmail Connector Module.
|
||||
|
||||
Provides Gmail specific methods for data retrieval and indexing via Composio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify as md
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
calculate_date_range,
|
||||
check_duplicate_document_by_hash,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# Heartbeat configuration
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""Get the current timestamp with timezone for updated_at field."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def check_document_by_unique_identifier(
|
||||
session: AsyncSession, unique_identifier_hash: str
|
||||
) -> Document | None:
|
||||
"""Check if a document with the given unique identifier hash already exists."""
|
||||
existing_doc_result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.where(Document.unique_identifier_hash == unique_identifier_hash)
|
||||
)
|
||||
return existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
update_last_indexed: bool = True,
|
||||
) -> None:
|
||||
"""Update the last_indexed_at timestamp for a connector."""
|
||||
if update_last_indexed:
|
||||
connector.last_indexed_at = datetime.now(UTC)
|
||||
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
||||
|
||||
|
||||
class ComposioGmailConnector(ComposioConnector):
|
||||
"""
|
||||
Gmail specific Composio connector.
|
||||
|
||||
Provides methods for listing messages, getting message details, and formatting
|
||||
Gmail messages from Gmail via Composio.
|
||||
"""
|
||||
|
||||
async def list_gmail_messages(
|
||||
self,
|
||||
query: str = "",
|
||||
max_results: int = 50,
|
||||
page_token: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None, int | None, str | None]:
|
||||
"""
|
||||
List Gmail messages via Composio with pagination support.
|
||||
|
||||
Args:
|
||||
query: Gmail search query.
|
||||
max_results: Maximum number of messages per page (default: 50).
|
||||
page_token: Optional pagination token for next page.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages list, next_page_token, result_size_estimate, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return [], None, None, "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_gmail_messages(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
page_token=page_token,
|
||||
)
|
||||
|
||||
async def get_gmail_message_detail(
|
||||
self, message_id: str
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Get full details of a Gmail message via Composio.
|
||||
|
||||
Args:
|
||||
message_id: Gmail message ID.
|
||||
|
||||
Returns:
|
||||
Tuple of (message details, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return None, "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_gmail_message_detail(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _html_to_markdown(html: str) -> str:
|
||||
"""Convert HTML (especially email layouts with nested tables) to clean markdown."""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
for tag in soup.find_all(["style", "script", "img"]):
|
||||
tag.decompose()
|
||||
for tag in soup.find_all(
|
||||
["table", "thead", "tbody", "tfoot", "tr", "td", "th"]
|
||||
):
|
||||
tag.unwrap()
|
||||
return md(str(soup)).strip()
|
||||
|
||||
def format_gmail_message_to_markdown(self, message: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a Gmail message to markdown.
|
||||
|
||||
Args:
|
||||
message: Message object from Composio's GMAIL_FETCH_EMAILS response.
|
||||
Composio structure: messageId, messageText, messageTimestamp,
|
||||
payload.headers, labelIds, attachmentList
|
||||
|
||||
Returns:
|
||||
Formatted markdown string.
|
||||
"""
|
||||
try:
|
||||
# Composio uses 'messageId' (camelCase)
|
||||
message_id = message.get("messageId", "") or message.get("id", "")
|
||||
label_ids = message.get("labelIds", [])
|
||||
|
||||
# Extract headers from payload
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
|
||||
# Parse headers into a dict
|
||||
header_dict = {}
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
header_dict[name] = value
|
||||
|
||||
# Extract key information
|
||||
subject = header_dict.get("subject", "No Subject")
|
||||
from_email = header_dict.get("from", "Unknown Sender")
|
||||
to_email = header_dict.get("to", "Unknown Recipient")
|
||||
# Composio provides messageTimestamp directly
|
||||
date_str = message.get("messageTimestamp", "") or header_dict.get(
|
||||
"date", "Unknown Date"
|
||||
)
|
||||
|
||||
# Build markdown content
|
||||
markdown_content = f"# {subject}\n\n"
|
||||
markdown_content += f"**From:** {from_email}\n"
|
||||
markdown_content += f"**To:** {to_email}\n"
|
||||
markdown_content += f"**Date:** {date_str}\n"
|
||||
|
||||
if label_ids:
|
||||
markdown_content += f"**Labels:** {', '.join(label_ids)}\n"
|
||||
|
||||
markdown_content += "\n---\n\n"
|
||||
|
||||
# Composio provides full message text in 'messageText' which is often raw HTML
|
||||
message_text = message.get("messageText", "")
|
||||
if message_text:
|
||||
message_text = self._html_to_markdown(message_text)
|
||||
markdown_content += f"## Content\n\n{message_text}\n\n"
|
||||
else:
|
||||
# Fallback to snippet if no messageText
|
||||
snippet = message.get("snippet", "")
|
||||
if snippet:
|
||||
markdown_content += f"## Preview\n\n{snippet}\n\n"
|
||||
|
||||
# Add attachment info if present
|
||||
attachments = message.get("attachmentList", [])
|
||||
if attachments:
|
||||
markdown_content += "## Attachments\n\n"
|
||||
for att in attachments:
|
||||
att_name = att.get("filename", att.get("name", "Unknown"))
|
||||
markdown_content += f"- {att_name}\n"
|
||||
markdown_content += "\n"
|
||||
|
||||
# Add message metadata
|
||||
markdown_content += "## Message Details\n\n"
|
||||
markdown_content += f"- **Message ID:** {message_id}\n"
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
return f"Error formatting message to markdown: {e!s}"
|
||||
|
||||
|
||||
# ============ Indexer Functions ============
|
||||
|
||||
|
||||
async def _analyze_gmail_messages_phase1(
|
||||
session: AsyncSession,
|
||||
messages: list[dict[str, Any]],
|
||||
composio_connector: ComposioGmailConnector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> tuple[list[dict[str, Any]], int, int]:
|
||||
"""
|
||||
Phase 1: Analyze all messages, create pending documents.
|
||||
Makes ALL documents visible in the UI immediately with pending status.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages_to_process, documents_skipped, duplicate_content_count)
|
||||
"""
|
||||
messages_to_process = []
|
||||
documents_skipped = 0
|
||||
duplicate_content_count = 0
|
||||
|
||||
for message in messages:
|
||||
try:
|
||||
# Composio uses 'messageId' (camelCase), not 'id'
|
||||
message_id = message.get("messageId", "") or message.get("id", "")
|
||||
if not message_id:
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Extract message info from Composio response
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
|
||||
subject = "No Subject"
|
||||
sender = "Unknown Sender"
|
||||
date_str = message.get("messageTimestamp", "Unknown Date")
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
if name == "subject":
|
||||
subject = value
|
||||
elif name == "from":
|
||||
sender = value
|
||||
elif name == "date":
|
||||
date_str = value
|
||||
|
||||
# Format to markdown using the full message data
|
||||
markdown_content = composio_connector.format_gmail_message_to_markdown(
|
||||
message
|
||||
)
|
||||
|
||||
# Check for empty content
|
||||
if not markdown_content.strip():
|
||||
logger.warning(f"Skipping Gmail message with no content: {subject}")
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Generate unique identifier
|
||||
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"])
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
document_type, f"gmail_{message_id}", search_space_id
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Get label IDs and thread_id from Composio response
|
||||
label_ids = message.get("labelIds", [])
|
||||
thread_id = message.get("threadId", "") or message.get("thread_id", "")
|
||||
|
||||
if existing_document:
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
"label_ids": label_ids,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from standard connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Message {subject} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=subject,
|
||||
document_type=DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["gmail"]),
|
||||
document_metadata={
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date": date_str,
|
||||
"labels": label_ids,
|
||||
"connector_id": connector_id,
|
||||
"toolkit_id": "gmail",
|
||||
"source": "composio",
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
|
||||
messages_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"subject": subject,
|
||||
"sender": sender,
|
||||
"date_str": date_str,
|
||||
"label_ids": label_ids,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
return messages_to_process, documents_skipped, duplicate_content_count
|
||||
|
||||
|
||||
async def _process_gmail_messages_phase2(
|
||||
session: AsyncSession,
|
||||
messages_to_process: list[dict[str, Any]],
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Phase 2: Process each document one by one.
|
||||
Each document transitions: pending → processing → ready/failed
|
||||
|
||||
Returns:
|
||||
Tuple of (documents_indexed, documents_failed)
|
||||
"""
|
||||
documents_indexed = 0
|
||||
documents_failed = 0
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
for item in messages_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"document_type": "Gmail Message (Composio)",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
item["markdown_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["subject"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
"subject": item["subject"],
|
||||
"sender": item["sender"],
|
||||
"date": item["date_str"],
|
||||
"labels": item["label_ids"],
|
||||
"connector_id": connector_id,
|
||||
"source": "composio",
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Gmail messages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
return documents_indexed, documents_failed
|
||||
|
||||
|
||||
async def index_composio_gmail(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
update_last_indexed: bool = True,
|
||||
max_items: int = 1000,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Index Gmail messages via Composio with real-time document status updates."""
|
||||
try:
|
||||
composio_connector = ComposioGmailConnector(session, connector_id)
|
||||
|
||||
# Normalize date values - handle "undefined" strings from frontend
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
|
||||
if start_date is not None and end_date is not None:
|
||||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
else:
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Build query with date range
|
||||
query_parts = []
|
||||
if start_date_str:
|
||||
query_parts.append(f"after:{start_date_str.replace('-', '/')}")
|
||||
if end_date_str:
|
||||
query_parts.append(f"before:{end_date_str.replace('-', '/')}")
|
||||
query = " ".join(query_parts) if query_parts else ""
|
||||
|
||||
logger.info(
|
||||
f"Gmail query for connector {connector_id}: '{query}' "
|
||||
f"(start_date={start_date_str}, end_date={end_date_str})"
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Gmail messages via Composio for connector {connector_id}",
|
||||
{"stage": "fetching_messages"},
|
||||
)
|
||||
|
||||
# =======================================================================
|
||||
# FETCH ALL MESSAGES FIRST
|
||||
# =======================================================================
|
||||
batch_size = 50
|
||||
page_token = None
|
||||
all_messages = []
|
||||
result_size_estimate = None
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
while len(all_messages) < max_items:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(len(all_messages))
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
remaining = max_items - len(all_messages)
|
||||
current_batch_size = min(batch_size, remaining)
|
||||
|
||||
(
|
||||
messages,
|
||||
next_token,
|
||||
result_size_estimate_batch,
|
||||
error,
|
||||
) = await composio_connector.list_gmail_messages(
|
||||
query=query,
|
||||
max_results=current_batch_size,
|
||||
page_token=page_token,
|
||||
)
|
||||
|
||||
if error:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, f"Failed to fetch Gmail messages: {error}", {}
|
||||
)
|
||||
return 0, f"Failed to fetch Gmail messages: {error}"
|
||||
|
||||
if not messages:
|
||||
break
|
||||
|
||||
if result_size_estimate is None and result_size_estimate_batch is not None:
|
||||
result_size_estimate = result_size_estimate_batch
|
||||
logger.info(
|
||||
f"Gmail API estimated {result_size_estimate} total messages for query: '{query}'"
|
||||
)
|
||||
|
||||
all_messages.extend(messages)
|
||||
logger.info(
|
||||
f"Fetched {len(messages)} messages (total: {len(all_messages)})"
|
||||
)
|
||||
|
||||
if not next_token or len(messages) < current_batch_size:
|
||||
break
|
||||
|
||||
page_token = next_token
|
||||
|
||||
if not all_messages:
|
||||
success_msg = "No Gmail messages found in the specified date range"
|
||||
await task_logger.log_task_success(
|
||||
log_entry, success_msg, {"messages_count": 0}
|
||||
)
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
await session.commit()
|
||||
return (
|
||||
0,
|
||||
None,
|
||||
) # Return None (not error) when no items found - this is success with 0 items
|
||||
|
||||
logger.info(f"Found {len(all_messages)} Gmail messages to index via Composio")
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all messages, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Phase 1: Creating pending documents for {len(all_messages)} messages",
|
||||
{"stage": "phase1_pending"},
|
||||
)
|
||||
|
||||
(
|
||||
messages_to_process,
|
||||
documents_skipped,
|
||||
duplicate_content_count,
|
||||
) = await _analyze_gmail_messages_phase1(
|
||||
session=session,
|
||||
messages=all_messages,
|
||||
composio_connector=composio_connector,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
new_documents_count = len([m for m in messages_to_process if m["is_new"]])
|
||||
if new_documents_count > 0:
|
||||
logger.info(f"Phase 1: Committing {new_documents_count} pending documents")
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Phase 2: Processing {len(messages_to_process)} documents",
|
||||
{"stage": "phase2_processing"},
|
||||
)
|
||||
|
||||
documents_indexed, documents_failed = await _process_gmail_messages_phase2(
|
||||
session=session,
|
||||
messages_to_process=messages_to_process,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
# CRITICAL: Always update timestamp so Electric SQL syncs
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted
|
||||
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully committed all Composio Gmail document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Gmail indexing via Composio for connector {connector_id}",
|
||||
{
|
||||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Composio Gmail indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to index Gmail via Composio: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Gmail via Composio: {e!s}"
|
||||
|
|
@ -1,566 +0,0 @@
|
|||
"""
|
||||
Composio Google Calendar Connector Module.
|
||||
|
||||
Provides Google Calendar specific methods for data retrieval and indexing via Composio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
calculate_date_range,
|
||||
check_duplicate_document_by_hash,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
# Heartbeat configuration
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""Get the current timestamp with timezone for updated_at field."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def check_document_by_unique_identifier(
|
||||
session: AsyncSession, unique_identifier_hash: str
|
||||
) -> Document | None:
|
||||
"""Check if a document with the given unique identifier hash already exists."""
|
||||
existing_doc_result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.where(Document.unique_identifier_hash == unique_identifier_hash)
|
||||
)
|
||||
return existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
update_last_indexed: bool = True,
|
||||
) -> None:
|
||||
"""Update the last_indexed_at timestamp for a connector."""
|
||||
if update_last_indexed:
|
||||
connector.last_indexed_at = datetime.now(UTC)
|
||||
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
||||
|
||||
|
||||
class ComposioGoogleCalendarConnector(ComposioConnector):
|
||||
"""
|
||||
Google Calendar specific Composio connector.
|
||||
|
||||
Provides methods for listing calendar events and formatting them from
|
||||
Google Calendar via Composio.
|
||||
"""
|
||||
|
||||
async def list_calendar_events(
|
||||
self,
|
||||
time_min: str | None = None,
|
||||
time_max: str | None = None,
|
||||
max_results: int = 250,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
List Google Calendar events via Composio.
|
||||
|
||||
Args:
|
||||
time_min: Start time (RFC3339 format).
|
||||
time_max: End time (RFC3339 format).
|
||||
max_results: Maximum number of events.
|
||||
|
||||
Returns:
|
||||
Tuple of (events list, error message).
|
||||
"""
|
||||
connected_account_id = await self.get_connected_account_id()
|
||||
if not connected_account_id:
|
||||
return [], "No connected account ID found"
|
||||
|
||||
entity_id = await self.get_entity_id()
|
||||
service = await self._get_service()
|
||||
return await service.get_calendar_events(
|
||||
connected_account_id=connected_account_id,
|
||||
entity_id=entity_id,
|
||||
time_min=time_min,
|
||||
time_max=time_max,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
def format_calendar_event_to_markdown(self, event: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a Google Calendar event to markdown.
|
||||
|
||||
Args:
|
||||
event: Event object from Google Calendar API.
|
||||
|
||||
Returns:
|
||||
Formatted markdown string.
|
||||
"""
|
||||
try:
|
||||
# Extract basic event information
|
||||
summary = event.get("summary", "No Title")
|
||||
description = event.get("description", "")
|
||||
location = event.get("location", "")
|
||||
|
||||
# Extract start and end times
|
||||
start = event.get("start", {})
|
||||
end = event.get("end", {})
|
||||
|
||||
start_time = start.get("dateTime") or start.get("date", "")
|
||||
end_time = end.get("dateTime") or end.get("date", "")
|
||||
|
||||
# Format times for display
|
||||
def format_time(time_str: str) -> str:
|
||||
if not time_str:
|
||||
return "Unknown"
|
||||
try:
|
||||
if "T" in time_str:
|
||||
dt = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
return time_str
|
||||
except Exception:
|
||||
return time_str
|
||||
|
||||
start_formatted = format_time(start_time)
|
||||
end_formatted = format_time(end_time)
|
||||
|
||||
# Extract attendees
|
||||
attendees = event.get("attendees", [])
|
||||
attendee_list = []
|
||||
for attendee in attendees:
|
||||
email = attendee.get("email", "")
|
||||
display_name = attendee.get("displayName", email)
|
||||
response_status = attendee.get("responseStatus", "")
|
||||
attendee_list.append(f"- {display_name} ({response_status})")
|
||||
|
||||
# Build markdown content
|
||||
markdown_content = f"# {summary}\n\n"
|
||||
markdown_content += f"**Start:** {start_formatted}\n"
|
||||
markdown_content += f"**End:** {end_formatted}\n"
|
||||
|
||||
if location:
|
||||
markdown_content += f"**Location:** {location}\n"
|
||||
|
||||
markdown_content += "\n"
|
||||
|
||||
if description:
|
||||
markdown_content += f"## Description\n\n{description}\n\n"
|
||||
|
||||
if attendee_list:
|
||||
markdown_content += "## Attendees\n\n"
|
||||
markdown_content += "\n".join(attendee_list)
|
||||
markdown_content += "\n\n"
|
||||
|
||||
# Add event metadata
|
||||
markdown_content += "## Event Details\n\n"
|
||||
markdown_content += f"- **Event ID:** {event.get('id', 'Unknown')}\n"
|
||||
markdown_content += f"- **Created:** {event.get('created', 'Unknown')}\n"
|
||||
markdown_content += f"- **Updated:** {event.get('updated', 'Unknown')}\n"
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
return f"Error formatting event to markdown: {e!s}"
|
||||
|
||||
|
||||
# ============ Indexer Functions ============
|
||||
|
||||
|
||||
async def index_composio_google_calendar(
|
||||
session: AsyncSession,
|
||||
connector,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
update_last_indexed: bool = True,
|
||||
max_items: int = 2500,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Index Google Calendar events via Composio."""
|
||||
try:
|
||||
composio_connector = ComposioGoogleCalendarConnector(session, connector_id)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching Google Calendar events via Composio for connector {connector_id}",
|
||||
{"stage": "fetching_events"},
|
||||
)
|
||||
|
||||
# Normalize date values - handle "undefined" strings from frontend
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Use provided dates directly if both are provided, otherwise calculate from last_indexed_at
|
||||
# This ensures user-selected dates are respected (matching non-Composio Calendar connector behavior)
|
||||
if start_date is not None and end_date is not None:
|
||||
# User provided both dates - use them directly
|
||||
start_date_str = start_date
|
||||
end_date_str = end_date
|
||||
else:
|
||||
# Calculate date range with defaults (uses last_indexed_at or 365 days back)
|
||||
# This ensures indexing works even when user doesn't specify dates
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
)
|
||||
|
||||
# Build time range for API call
|
||||
time_min = f"{start_date_str}T00:00:00Z"
|
||||
time_max = f"{end_date_str}T23:59:59Z"
|
||||
|
||||
logger.info(
|
||||
f"Google Calendar query for connector {connector_id}: "
|
||||
f"(start_date={start_date_str}, end_date={end_date_str})"
|
||||
)
|
||||
|
||||
events, error = await composio_connector.list_calendar_events(
|
||||
time_min=time_min,
|
||||
time_max=time_max,
|
||||
max_results=max_items,
|
||||
)
|
||||
|
||||
if error:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, f"Failed to fetch Calendar events: {error}", {}
|
||||
)
|
||||
return 0, f"Failed to fetch Calendar events: {error}"
|
||||
|
||||
if not events:
|
||||
success_msg = "No Google Calendar events found in the specified date range"
|
||||
await task_logger.log_task_success(
|
||||
log_entry, success_msg, {"events_count": 0}
|
||||
)
|
||||
# CRITICAL: Update timestamp even when no events found so Electric SQL syncs and UI shows indexed status
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
await session.commit()
|
||||
return (
|
||||
0,
|
||||
None,
|
||||
) # Return None (not error) when no items found - this is success with 0 items
|
||||
|
||||
logger.info(f"Found {len(events)} Google Calendar events to index via Composio")
|
||||
|
||||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track events that failed processing
|
||||
duplicate_content_count = (
|
||||
0 # Track events skipped due to duplicate content_hash
|
||||
)
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Analyze all events, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# =======================================================================
|
||||
events_to_process = [] # List of dicts with document and event data
|
||||
new_documents_created = False
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
# Handle both standard Google API and potential Composio variations
|
||||
event_id = event.get("id", "") or event.get("eventId", "")
|
||||
summary = (
|
||||
event.get("summary", "") or event.get("title", "") or "No Title"
|
||||
)
|
||||
|
||||
if not event_id:
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Format to markdown
|
||||
markdown_content = composio_connector.format_calendar_event_to_markdown(
|
||||
event
|
||||
)
|
||||
|
||||
# Generate unique identifier
|
||||
document_type = DocumentType(TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"])
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
document_type, f"calendar_{event_id}", search_space_id
|
||||
)
|
||||
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
|
||||
# Extract event times
|
||||
start = event.get("start", {})
|
||||
end = event.get("end", {})
|
||||
start_time = start.get("dateTime") or start.get("date", "")
|
||||
end_time = end.get("dateTime") or end.get("date", "")
|
||||
location = event.get("location", "")
|
||||
|
||||
if existing_document:
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from standard connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = await check_duplicate_document_by_hash(
|
||||
session, content_hash
|
||||
)
|
||||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Event {summary} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=summary,
|
||||
document_type=DocumentType(
|
||||
TOOLKIT_TO_DOCUMENT_TYPE["googlecalendar"]
|
||||
),
|
||||
document_metadata={
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
"connector_id": connector_id,
|
||||
"toolkit_id": "googlecalendar",
|
||||
"source": "composio",
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
chunks=[], # Empty at creation - safe for async
|
||||
status=DocumentStatus.pending(), # Pending until processing starts
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
events_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"markdown_content": markdown_content,
|
||||
"content_hash": content_hash,
|
||||
"event_id": event_id,
|
||||
"summary": summary,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"location": location,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
|
||||
|
||||
for item in events_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
document = item["document"]
|
||||
try:
|
||||
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
|
||||
document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
# Heavy processing (LLM, embeddings, chunks)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
"start_time": item["start_time"],
|
||||
"document_type": "Google Calendar Event (Composio)",
|
||||
}
|
||||
(
|
||||
summary_content,
|
||||
summary_embedding,
|
||||
) = await generate_document_summary(
|
||||
item["markdown_content"],
|
||||
user_llm,
|
||||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["summary"]
|
||||
document.content = summary_content
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
"start_time": item["start_time"],
|
||||
"end_time": item["end_time"],
|
||||
"location": item["location"],
|
||||
"connector_id": connector_id,
|
||||
"source": "composio",
|
||||
}
|
||||
await safe_set_chunks(session, document, chunks)
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready()
|
||||
|
||||
documents_indexed += 1
|
||||
|
||||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Electric SQL syncs
|
||||
# This ensures the UI shows "Last indexed" instead of "Never indexed"
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit to ensure all documents are persisted (safety net)
|
||||
# This matches the pattern used in non-Composio Gmail indexer
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Google Calendar events processed"
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Successfully committed all Composio Google Calendar document changes to database"
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any remaining integrity errors gracefully (race conditions, etc.)
|
||||
if (
|
||||
"duplicate key value violates unique constraint" in str(e).lower()
|
||||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"This may occur if the same event was indexed by multiple connectors. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
# Don't fail the entire task - some documents may have been successfully indexed
|
||||
else:
|
||||
raise
|
||||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Google Calendar indexing via Composio for connector {connector_id}",
|
||||
{
|
||||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Composio Google Calendar indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True
|
||||
)
|
||||
return 0, f"Failed to index Google Calendar via Composio: {e!s}"
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -14,7 +14,6 @@ from sqlalchemy.future import select
|
|||
from app.config import config
|
||||
from app.connectors.confluence_connector import ConfluenceConnector
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.confluence_add_connector_route import refresh_confluence_token
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -190,7 +189,11 @@ class ConfluenceHistoryConnector:
|
|||
f"Connector {self._connector_id} not found; cannot refresh token."
|
||||
)
|
||||
|
||||
# Refresh token
|
||||
# Lazy import to avoid circular dependency
|
||||
from app.routes.confluence_add_connector_route import (
|
||||
refresh_confluence_token,
|
||||
)
|
||||
|
||||
connector = await refresh_confluence_token(self._session, connector)
|
||||
|
||||
# Reload credentials after refresh
|
||||
|
|
@ -341,6 +344,61 @@ class ConfluenceHistoryConnector:
|
|||
logger.error(f"Confluence API request error: {e!s}", exc_info=True)
|
||||
raise Exception(f"Confluence API request failed: {e!s}") from e
|
||||
|
||||
async def _make_api_request_with_method(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "GET",
|
||||
json_payload: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Make a request to the Confluence API with a specified HTTP method."""
|
||||
if not self._use_oauth:
|
||||
raise ValueError("Write operations require OAuth authentication")
|
||||
|
||||
token = await self._get_valid_token()
|
||||
base_url = await self._get_base_url()
|
||||
http_client = await self._get_client()
|
||||
|
||||
url = f"{base_url}/wiki/api/v2/{endpoint}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
method_upper = method.upper()
|
||||
if method_upper == "POST":
|
||||
response = await http_client.post(
|
||||
url, headers=headers, json=json_payload, params=params
|
||||
)
|
||||
elif method_upper == "PUT":
|
||||
response = await http_client.put(
|
||||
url, headers=headers, json=json_payload, params=params
|
||||
)
|
||||
elif method_upper == "DELETE":
|
||||
response = await http_client.delete(url, headers=headers, params=params)
|
||||
else:
|
||||
response = await http_client.get(url, headers=headers, params=params)
|
||||
|
||||
response.raise_for_status()
|
||||
if response.status_code == 204 or not response.text:
|
||||
return {"status": "success"}
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = {
|
||||
"status_code": e.response.status_code,
|
||||
"url": str(e.request.url),
|
||||
"response_text": e.response.text,
|
||||
}
|
||||
logger.error(f"Confluence API HTTP error: {error_detail}")
|
||||
raise Exception(
|
||||
f"Confluence API request failed (HTTP {e.response.status_code}): {e.response.text}"
|
||||
) from e
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Confluence API request error: {e!s}", exc_info=True)
|
||||
raise Exception(f"Confluence API request failed: {e!s}") from e
|
||||
|
||||
async def get_all_spaces(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all spaces from Confluence.
|
||||
|
|
@ -593,6 +651,65 @@ class ConfluenceHistoryConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching pages: {e!s}"
|
||||
|
||||
async def get_page(self, page_id: str) -> dict[str, Any]:
|
||||
"""Fetch a single page by ID with body content."""
|
||||
return await self._make_api_request(
|
||||
f"pages/{page_id}", params={"body-format": "storage"}
|
||||
)
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
space_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
parent_page_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Confluence page."""
|
||||
payload: dict[str, Any] = {
|
||||
"spaceId": space_id,
|
||||
"title": title,
|
||||
"body": {
|
||||
"representation": "storage",
|
||||
"value": body,
|
||||
},
|
||||
"status": "current",
|
||||
}
|
||||
if parent_page_id:
|
||||
payload["parentId"] = parent_page_id
|
||||
return await self._make_api_request_with_method(
|
||||
"pages", method="POST", json_payload=payload
|
||||
)
|
||||
|
||||
async def update_page(
|
||||
self,
|
||||
page_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
version_number: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Confluence page (requires version number)."""
|
||||
payload: dict[str, Any] = {
|
||||
"id": page_id,
|
||||
"title": title,
|
||||
"body": {
|
||||
"representation": "storage",
|
||||
"value": body,
|
||||
},
|
||||
"version": {
|
||||
"number": version_number,
|
||||
},
|
||||
"status": "current",
|
||||
}
|
||||
return await self._make_api_request_with_method(
|
||||
f"pages/{page_id}", method="PUT", json_payload=payload
|
||||
)
|
||||
|
||||
async def delete_page(self, page_id: str) -> dict[str, Any]:
|
||||
"""Delete a Confluence page."""
|
||||
return await self._make_api_request_with_method(
|
||||
f"pages/{page_id}", method="DELETE"
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client connection."""
|
||||
if self._http_client:
|
||||
|
|
|
|||
|
|
@ -52,44 +52,39 @@ class GoogleCalendarConnector:
|
|||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
Google OAuth credentials
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If credential refresh fails
|
||||
|
||||
Supports both native OAuth (with refresh_token) and Composio-sourced
|
||||
credentials (with refresh_handler). For Composio credentials, validation
|
||||
and DB persistence are skipped since Composio manages its own tokens.
|
||||
"""
|
||||
if not all(
|
||||
[
|
||||
self._credentials.client_id,
|
||||
self._credentials.client_secret,
|
||||
self._credentials.refresh_token,
|
||||
]
|
||||
has_standard_refresh = bool(self._credentials.refresh_token)
|
||||
|
||||
if has_standard_refresh and not all(
|
||||
[self._credentials.client_id, self._credentials.client_secret]
|
||||
):
|
||||
raise ValueError(
|
||||
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
|
||||
"Google OAuth credentials (client_id, client_secret) must be set"
|
||||
)
|
||||
|
||||
if self._credentials and not self._credentials.expired:
|
||||
return self._credentials
|
||||
|
||||
# Create credentials from refresh token
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
if has_standard_refresh:
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
# Use connector_id if available, otherwise fall back to user_id query
|
||||
# Only persist refreshed token for native OAuth (Composio manages its own)
|
||||
if has_standard_refresh and self._session:
|
||||
if self._connector_id:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -110,7 +105,6 @@ class GoogleCalendarConnector:
|
|||
"GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
|
||||
)
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -119,7 +113,6 @@ class GoogleCalendarConnector:
|
|||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Encrypt sensitive fields
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
|
|
@ -143,7 +136,6 @@ class GoogleCalendarConnector:
|
|||
await self._session.commit()
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Check if this is an invalid_grant error (token expired/revoked)
|
||||
if (
|
||||
"invalid_grant" in error_str.lower()
|
||||
or "token has been expired or revoked" in error_str.lower()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import io
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseUpload
|
||||
|
|
@ -15,16 +16,24 @@ from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
|||
class GoogleDriveClient:
|
||||
"""Client for Google Drive API operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession, connector_id: int):
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
credentials: "Credentials | None" = None,
|
||||
):
|
||||
"""
|
||||
Initialize Google Drive client.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Drive connector
|
||||
credentials: Pre-built credentials (e.g. from Composio). If None,
|
||||
credentials are loaded from the DB connector config.
|
||||
"""
|
||||
self.session = session
|
||||
self.connector_id = connector_id
|
||||
self._credentials = credentials
|
||||
self.service = None
|
||||
|
||||
async def get_service(self):
|
||||
|
|
@ -41,7 +50,12 @@ class GoogleDriveClient:
|
|||
return self.service
|
||||
|
||||
try:
|
||||
credentials = await get_valid_credentials(self.session, self.connector_id)
|
||||
if self._credentials:
|
||||
credentials = self._credentials
|
||||
else:
|
||||
credentials = await get_valid_credentials(
|
||||
self.session, self.connector_id
|
||||
)
|
||||
self.service = build("drive", "v3", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ async def download_and_process_file(
|
|||
task_logger: TaskLoggingService,
|
||||
log_entry: Log,
|
||||
connector_id: int | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[Any, str | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Download Google Drive file and process using Surfsense file processors.
|
||||
|
|
@ -95,6 +96,7 @@ async def download_and_process_file(
|
|||
},
|
||||
}
|
||||
# Include connector_id for de-indexing support
|
||||
connector_info["enable_summary"] = enable_summary
|
||||
if connector_id is not None:
|
||||
connector_info["connector_id"] = connector_id
|
||||
|
||||
|
|
|
|||
|
|
@ -81,44 +81,39 @@ class GoogleGmailConnector:
|
|||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
Google OAuth credentials
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If credential refresh fails
|
||||
|
||||
Supports both native OAuth (with refresh_token) and Composio-sourced
|
||||
credentials (with refresh_handler). For Composio credentials, validation
|
||||
and DB persistence are skipped since Composio manages its own tokens.
|
||||
"""
|
||||
if not all(
|
||||
[
|
||||
self._credentials.client_id,
|
||||
self._credentials.client_secret,
|
||||
self._credentials.refresh_token,
|
||||
]
|
||||
has_standard_refresh = bool(self._credentials.refresh_token)
|
||||
|
||||
if has_standard_refresh and not all(
|
||||
[self._credentials.client_id, self._credentials.client_secret]
|
||||
):
|
||||
raise ValueError(
|
||||
"Google OAuth credentials (client_id, client_secret, refresh_token) must be set"
|
||||
"Google OAuth credentials (client_id, client_secret) must be set"
|
||||
)
|
||||
|
||||
if self._credentials and not self._credentials.expired:
|
||||
return self._credentials
|
||||
|
||||
# Create credentials from refresh token
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
if has_standard_refresh:
|
||||
self._credentials = Credentials(
|
||||
token=self._credentials.token,
|
||||
refresh_token=self._credentials.refresh_token,
|
||||
token_uri=self._credentials.token_uri,
|
||||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
# Use connector_id if available, otherwise fall back to user_id query
|
||||
# Only persist refreshed token for native OAuth (Composio manages its own)
|
||||
if has_standard_refresh and self._session:
|
||||
if self._connector_id:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -138,12 +133,38 @@ class GoogleGmailConnector:
|
|||
raise RuntimeError(
|
||||
"GMAIL connector not found; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
creds_dict = json.loads(self._credentials.to_json())
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
)
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = (
|
||||
token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = (
|
||||
token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
)
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
connector.config = creds_dict
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Check if this is an invalid_grant error (token expired/revoked)
|
||||
if (
|
||||
"invalid_grant" in error_str.lower()
|
||||
or "token has been expired or revoked" in error_str.lower()
|
||||
|
|
|
|||
|
|
@ -167,14 +167,23 @@ class JiraConnector:
|
|||
# Use direct base URL (works for both OAuth and legacy)
|
||||
url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"
|
||||
|
||||
if method.upper() == "POST":
|
||||
method_upper = method.upper()
|
||||
if method_upper == "POST":
|
||||
response = requests.post(
|
||||
url, headers=headers, json=json_payload, timeout=500
|
||||
)
|
||||
elif method_upper == "PUT":
|
||||
response = requests.put(
|
||||
url, headers=headers, json=json_payload, timeout=500
|
||||
)
|
||||
elif method_upper == "DELETE":
|
||||
response = requests.delete(url, headers=headers, params=params, timeout=500)
|
||||
else:
|
||||
response = requests.get(url, headers=headers, params=params, timeout=500)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.status_code in (200, 201, 204):
|
||||
if response.status_code == 204 or not response.text:
|
||||
return {"status": "success"}
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(
|
||||
|
|
@ -352,6 +361,91 @@ class JiraConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
def get_myself(self) -> dict[str, Any]:
|
||||
"""Fetch the current user's profile (health check)."""
|
||||
return self.make_api_request("myself")
|
||||
|
||||
def get_projects(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all projects the user has access to."""
|
||||
result = self.make_api_request("project/search")
|
||||
return result.get("values", [])
|
||||
|
||||
def get_issue_types(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all issue types."""
|
||||
return self.make_api_request("issuetype")
|
||||
|
||||
def get_priorities(self) -> list[dict[str, Any]]:
|
||||
"""Fetch all priority levels."""
|
||||
return self.make_api_request("priority")
|
||||
|
||||
def get_issue(self, issue_id_or_key: str) -> dict[str, Any]:
|
||||
"""Fetch a single issue by ID or key."""
|
||||
return self.make_api_request(f"issue/{issue_id_or_key}")
|
||||
|
||||
def create_issue(
|
||||
self,
|
||||
project_key: str,
|
||||
summary: str,
|
||||
issue_type: str = "Task",
|
||||
description: str | None = None,
|
||||
priority: str | None = None,
|
||||
assignee_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Jira issue."""
|
||||
fields: dict[str, Any] = {
|
||||
"project": {"key": project_key},
|
||||
"summary": summary,
|
||||
"issuetype": {"name": issue_type},
|
||||
}
|
||||
if description:
|
||||
fields["description"] = {
|
||||
"type": "doc",
|
||||
"version": 1,
|
||||
"content": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"content": [{"type": "text", "text": description}],
|
||||
}
|
||||
],
|
||||
}
|
||||
if priority:
|
||||
fields["priority"] = {"name": priority}
|
||||
if assignee_id:
|
||||
fields["assignee"] = {"accountId": assignee_id}
|
||||
|
||||
return self.make_api_request(
|
||||
"issue", method="POST", json_payload={"fields": fields}
|
||||
)
|
||||
|
||||
def update_issue(
|
||||
self, issue_id_or_key: str, fields: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Jira issue fields."""
|
||||
return self.make_api_request(
|
||||
f"issue/{issue_id_or_key}",
|
||||
method="PUT",
|
||||
json_payload={"fields": fields},
|
||||
)
|
||||
|
||||
def delete_issue(self, issue_id_or_key: str) -> dict[str, Any]:
|
||||
"""Delete a Jira issue."""
|
||||
return self.make_api_request(f"issue/{issue_id_or_key}", method="DELETE")
|
||||
|
||||
def get_transitions(self, issue_id_or_key: str) -> list[dict[str, Any]]:
|
||||
"""Get available transitions for an issue (for status changes)."""
|
||||
result = self.make_api_request(f"issue/{issue_id_or_key}/transitions")
|
||||
return result.get("transitions", [])
|
||||
|
||||
def transition_issue(
|
||||
self, issue_id_or_key: str, transition_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Transition an issue to a new status."""
|
||||
return self.make_api_request(
|
||||
f"issue/{issue_id_or_key}/transitions",
|
||||
method="POST",
|
||||
json_payload={"transition": {"id": transition_id}},
|
||||
)
|
||||
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.future import select
|
|||
from app.config import config
|
||||
from app.connectors.jira_connector import JiraConnector
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.jira_add_connector_route import refresh_jira_token
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
|
|
@ -184,7 +183,9 @@ class JiraHistoryConnector:
|
|||
f"Connector {self._connector_id} not found; cannot refresh token."
|
||||
)
|
||||
|
||||
# Refresh token
|
||||
# Lazy import to avoid circular dependency
|
||||
from app.routes.jira_add_connector_route import refresh_jira_token
|
||||
|
||||
connector = await refresh_jira_token(self._session, connector)
|
||||
|
||||
# Reload credentials after refresh
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from notion_client import AsyncClient
|
||||
from notion_client.errors import APIResponseError
|
||||
from notion_markdown import to_notion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -834,106 +834,8 @@ class NotionHistoryConnector:
|
|||
return None
|
||||
|
||||
def _markdown_to_blocks(self, markdown: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert markdown content to Notion blocks.
|
||||
|
||||
This is a simple converter that handles basic markdown.
|
||||
For more complex markdown, consider using a proper markdown parser.
|
||||
|
||||
Args:
|
||||
markdown: Markdown content
|
||||
|
||||
Returns:
|
||||
List of Notion block objects
|
||||
"""
|
||||
blocks = []
|
||||
lines = markdown.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Heading 1
|
||||
if line.startswith("# "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_1",
|
||||
"heading_1": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Heading 2
|
||||
elif line.startswith("## "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_2",
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[3:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Heading 3
|
||||
elif line.startswith("### "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_3",
|
||||
"heading_3": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[4:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Bullet list
|
||||
elif line.startswith("- ") or line.startswith("* "):
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "bulleted_list_item",
|
||||
"bulleted_list_item": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:]}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Numbered list
|
||||
elif match := re.match(r"^(\d+)\.\s+(.*)$", line):
|
||||
content = match.group(2) # Extract text after "number. "
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "numbered_list_item",
|
||||
"numbered_list_item": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": content}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Regular paragraph
|
||||
else:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "paragraph",
|
||||
"paragraph": {
|
||||
"rich_text": [{"type": "text", "text": {"content": line}}]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return blocks
|
||||
"""Convert markdown content to Notion blocks using notion-markdown."""
|
||||
return to_notion(markdown)
|
||||
|
||||
async def create_page(
|
||||
self, title: str, content: str, parent_page_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -63,6 +63,16 @@ class DocumentType(StrEnum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
# Native Google document types → their legacy Composio equivalents.
|
||||
# Old documents may still carry the Composio type until they are re-indexed;
|
||||
# search, browse, and indexing must transparently handle both.
|
||||
NATIVE_TO_LEGACY_DOCTYPE: dict[str, str] = {
|
||||
"GOOGLE_DRIVE_FILE": "COMPOSIO_GOOGLE_DRIVE_CONNECTOR",
|
||||
"GOOGLE_GMAIL_CONNECTOR": "COMPOSIO_GMAIL_CONNECTOR",
|
||||
"GOOGLE_CALENDAR_CONNECTOR": "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
|
||||
}
|
||||
|
||||
|
||||
class SearchSourceConnectorType(StrEnum):
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
TAVILY_API = "TAVILY_API"
|
||||
|
|
@ -103,6 +113,13 @@ class PodcastStatus(StrEnum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class VideoPresentationStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class DocumentStatus:
|
||||
"""
|
||||
Helper class for document processing status (stored as JSONB).
|
||||
|
|
@ -337,6 +354,12 @@ class Permission(StrEnum):
|
|||
PODCASTS_UPDATE = "podcasts:update"
|
||||
PODCASTS_DELETE = "podcasts:delete"
|
||||
|
||||
# Video Presentations
|
||||
VIDEO_PRESENTATIONS_CREATE = "video_presentations:create"
|
||||
VIDEO_PRESENTATIONS_READ = "video_presentations:read"
|
||||
VIDEO_PRESENTATIONS_UPDATE = "video_presentations:update"
|
||||
VIDEO_PRESENTATIONS_DELETE = "video_presentations:delete"
|
||||
|
||||
# Image Generations
|
||||
IMAGE_GENERATIONS_CREATE = "image_generations:create"
|
||||
IMAGE_GENERATIONS_READ = "image_generations:read"
|
||||
|
|
@ -403,6 +426,10 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.PODCASTS_CREATE.value,
|
||||
Permission.PODCASTS_READ.value,
|
||||
Permission.PODCASTS_UPDATE.value,
|
||||
# Video Presentations (no delete)
|
||||
Permission.VIDEO_PRESENTATIONS_CREATE.value,
|
||||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
Permission.VIDEO_PRESENTATIONS_UPDATE.value,
|
||||
# Image Generations (create and read, no delete)
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
|
|
@ -435,6 +462,8 @@ DEFAULT_ROLE_PERMISSIONS = {
|
|||
Permission.LLM_CONFIGS_READ.value,
|
||||
# Podcasts (read only)
|
||||
Permission.PODCASTS_READ.value,
|
||||
# Video Presentations (read only)
|
||||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
# Image Generations (read only)
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
# Connectors (read only)
|
||||
|
|
@ -693,7 +722,7 @@ class ChatComment(BaseModel, TimestampMixin):
|
|||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
# Denormalized thread_id for efficient Electric SQL subscriptions (one per thread)
|
||||
# Denormalized thread_id for efficient Zero subscriptions (one per thread)
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
|
|
@ -763,7 +792,7 @@ class ChatCommentMention(BaseModel, TimestampMixin):
|
|||
class ChatSessionState(BaseModel):
|
||||
"""
|
||||
Tracks real-time session state for shared chat collaboration.
|
||||
One record per thread, synced via Electric SQL.
|
||||
One record per thread, synced via Zero.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_session_state"
|
||||
|
|
@ -1044,6 +1073,46 @@ class Podcast(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class VideoPresentation(BaseModel, TimestampMixin):
|
||||
"""Video presentation model for storing AI-generated video presentations.
|
||||
|
||||
The slides JSONB stores per-slide data including Remotion component code,
|
||||
audio file paths, and durations. The frontend compiles the code and renders
|
||||
the video using Remotion Player.
|
||||
"""
|
||||
|
||||
__tablename__ = "video_presentations"
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
slides = Column(JSONB, nullable=True)
|
||||
scene_codes = Column(JSONB, nullable=True)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
VideoPresentationStatus,
|
||||
name="video_presentation_status",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=VideoPresentationStatus.READY,
|
||||
server_default="ready",
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="video_presentations")
|
||||
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class Report(BaseModel, TimestampMixin):
|
||||
"""Report model for storing generated Markdown reports."""
|
||||
|
||||
|
|
@ -1228,6 +1297,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Podcast.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
video_presentations = relationship(
|
||||
"VideoPresentation",
|
||||
back_populates="search_space",
|
||||
order_by="VideoPresentation.id.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
reports = relationship(
|
||||
"Report",
|
||||
back_populates="search_space",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -157,7 +158,7 @@ class ChucksHybridSearchRetriever:
|
|||
query_text: str,
|
||||
top_k: int,
|
||||
search_space_id: int,
|
||||
document_type: str | None = None,
|
||||
document_type: str | list[str] | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
|
|
@ -217,18 +218,24 @@ class ChucksHybridSearchRetriever:
|
|||
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
|
||||
]
|
||||
|
||||
# Add document type filter if provided
|
||||
# Add document type filter if provided (single string or list of strings)
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
# Add time-based filtering if provided
|
||||
if start_date is not None:
|
||||
|
|
@ -428,4 +435,4 @@ class ChucksHybridSearchRetriever:
|
|||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -149,7 +150,7 @@ class DocumentHybridSearchRetriever:
|
|||
query_text: str,
|
||||
top_k: int,
|
||||
search_space_id: int,
|
||||
document_type: str | None = None,
|
||||
document_type: str | list[str] | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
|
|
@ -197,18 +198,24 @@ class DocumentHybridSearchRetriever:
|
|||
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
|
||||
]
|
||||
|
||||
# Add document type filter if provided
|
||||
# Add document type filter if provided (single string or list of strings)
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
type_list = (
|
||||
document_type if isinstance(document_type, list) else [document_type]
|
||||
)
|
||||
doc_type_enums = []
|
||||
for dt in type_list:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if not doc_type_enums:
|
||||
return []
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
# Add time-based filtering if provided
|
||||
if start_date is not None:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from .search_spaces_routes import router as search_spaces_router
|
|||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
from .surfsense_docs_routes import router as surfsense_docs_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
from .youtube_routes import router as youtube_router
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -55,6 +56,9 @@ router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
|||
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(
|
||||
video_presentations_router
|
||||
) # Video presentation status and streaming
|
||||
router.include_router(reports_router) # Report CRUD and multi-format export
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(search_source_connectors_router)
|
||||
|
|
@ -76,7 +80,7 @@ router.include_router(model_list_router) # Dynamic LLM model catalogue from Ope
|
|||
router.include_router(logs_router)
|
||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||
router.include_router(notifications_router) # Notifications with Electric SQL sync
|
||||
router.include_router(notifications_router) # Notifications with Zero sync
|
||||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ async def airtable_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=airtable_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -316,7 +316,7 @@ async def airtable_callback(
|
|||
f"Duplicate Airtable connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=airtable-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=airtable-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -348,7 +348,7 @@ async def airtable_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=airtable-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=airtable-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ async def clickup_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=clickup_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=clickup_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -326,7 +326,7 @@ async def clickup_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=clickup-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=clickup-connector"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ async def composio_callback(
|
|||
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=composio_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=composio_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -263,6 +263,15 @@ async def composio_callback(
|
|||
logger.info(
|
||||
f"Successfully got connected_account_id: {final_connected_account_id}"
|
||||
)
|
||||
# Wait for Composio to finish exchanging the auth code for tokens.
|
||||
try:
|
||||
service.wait_for_connection(final_connected_account_id, timeout=30.0)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"wait_for_connection timed out for {final_connected_account_id}, "
|
||||
"proceeding anyway",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Build entity_id for Composio API calls (same format as used in initiate)
|
||||
entity_id = f"surfsense_{user_id}"
|
||||
|
|
@ -370,7 +379,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}"
|
||||
)
|
||||
|
||||
# This is a NEW account - create a new connector
|
||||
|
|
@ -399,7 +408,7 @@ async def composio_callback(
|
|||
toolkit_id, "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
|
|
@ -425,6 +434,211 @@ async def composio_callback(
|
|||
) from e
|
||||
|
||||
|
||||
COMPOSIO_CONNECTOR_TYPES = {
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/auth/composio/connector/reauth")
|
||||
async def reauth_composio_connector(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Initiate Composio re-authentication for an expired connected account.
|
||||
|
||||
Uses Composio's refresh API so the same connected_account_id stays valid
|
||||
after the user completes the OAuth flow again.
|
||||
|
||||
Query params:
|
||||
space_id: Search space ID the connector belongs to
|
||||
connector_id: ID of the existing Composio connector to re-authenticate
|
||||
return_url: Optional frontend path to redirect to after completion
|
||||
"""
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Composio integration is not enabled."
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type.in_(COMPOSIO_CONNECTOR_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Composio connector not found or access denied",
|
||||
)
|
||||
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Composio connected account ID not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
# Build callback URL with secure state
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id,
|
||||
user.id,
|
||||
toolkit_id=connector.config.get("toolkit_id", ""),
|
||||
connector_id=connector_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
|
||||
callback_base = config.COMPOSIO_REDIRECT_URI
|
||||
if not callback_base:
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
callback_base = (
|
||||
f"{backend_url}/api/v1/auth/composio/connector/reauth/callback"
|
||||
)
|
||||
else:
|
||||
# Replace the normal callback path with the reauth one
|
||||
callback_base = callback_base.replace(
|
||||
"/auth/composio/connector/callback",
|
||||
"/auth/composio/connector/reauth/callback",
|
||||
)
|
||||
|
||||
callback_url = f"{callback_base}?state={state_encoded}"
|
||||
|
||||
service = ComposioService()
|
||||
refresh_result = service.refresh_connected_account(
|
||||
connected_account_id=connected_account_id,
|
||||
redirect_url=callback_url,
|
||||
)
|
||||
|
||||
if refresh_result["redirect_url"] is None:
|
||||
# Token refreshed server-side; clear auth_expired immediately
|
||||
if connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": False}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Composio account {connected_account_id} refreshed server-side (no redirect needed)"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Authentication refreshed successfully.",
|
||||
}
|
||||
|
||||
logger.info(f"Initiating Composio re-auth for connector {connector_id}")
|
||||
return {"auth_url": refresh_result["redirect_url"]}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Composio re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Composio re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/composio/connector/reauth/callback")
|
||||
async def composio_reauth_callback(
|
||||
request: Request,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Handle Composio re-authentication callback.
|
||||
|
||||
Clears the auth_expired flag and redirects the user back to the frontend.
|
||||
The connected_account_id has not changed — Composio refreshed it in place.
|
||||
"""
|
||||
try:
|
||||
if not state:
|
||||
raise HTTPException(status_code=400, detail="Missing state parameter")
|
||||
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
return_url = data.get("return_url")
|
||||
|
||||
if not reauth_connector_id:
|
||||
raise HTTPException(status_code=400, detail="Missing connector_id in state")
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth callback",
|
||||
)
|
||||
|
||||
# Wait for Composio to finish processing new tokens before proceeding.
|
||||
# Without this, get_access_token() may return stale credentials.
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if connected_account_id:
|
||||
try:
|
||||
service = ComposioService()
|
||||
service.wait_for_connection(connected_account_id, timeout=30.0)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"wait_for_connection timed out for connector {reauth_connector_id}, "
|
||||
"proceeding anyway — tokens may not be ready yet",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Clear auth_expired flag
|
||||
connector.config = {**connector.config, "auth_expired": False}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(f"Composio re-auth completed for connector {reauth_connector_id}")
|
||||
|
||||
if return_url and return_url.startswith("/"):
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{return_url}")
|
||||
|
||||
frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
|
||||
connector.config.get("toolkit_id", ""), "composio-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector={frontend_connector_id}&connectorId={reauth_connector_id}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Composio reauth callback: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete Composio re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/connectors/{connector_id}/composio-drive/folders")
|
||||
async def list_composio_drive_folders(
|
||||
connector_id: int,
|
||||
|
|
@ -433,31 +647,23 @@ async def list_composio_drive_folders(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List folders AND files in user's Google Drive via Composio with hierarchical support.
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
||||
This is called at index time from the manage connector page to display
|
||||
the complete file system (folders and files). Only folders are selectable.
|
||||
|
||||
Args:
|
||||
connector_id: ID of the Composio Google Drive connector
|
||||
parent_id: Optional parent folder ID to list contents (None for root)
|
||||
|
||||
Returns:
|
||||
JSON with list of items: {
|
||||
"items": [
|
||||
{"id": str, "name": str, "mimeType": str, "isFolder": bool, ...},
|
||||
...
|
||||
]
|
||||
}
|
||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
||||
connector, with Composio-sourced credentials. This means auth errors
|
||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
||||
"""
|
||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Composio integration is not enabled.",
|
||||
)
|
||||
|
||||
connector = None
|
||||
try:
|
||||
# Get connector and verify ownership
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
|
|
@ -474,7 +680,6 @@ async def list_composio_drive_folders(
|
|||
detail="Composio Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
# Get Composio connected account ID from config
|
||||
composio_connected_account_id = connector.config.get(
|
||||
"composio_connected_account_id"
|
||||
)
|
||||
|
|
@ -484,63 +689,43 @@ async def list_composio_drive_folders(
|
|||
detail="Composio connected account not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
# Initialize Composio service and fetch files
|
||||
service = ComposioService()
|
||||
entity_id = f"surfsense_{user.id}"
|
||||
credentials = build_composio_credentials(composio_connected_account_id)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
||||
|
||||
# Fetch files/folders from Composio Google Drive
|
||||
files, _next_token, error = await service.get_drive_files(
|
||||
connected_account_id=composio_connected_account_id,
|
||||
entity_id=entity_id,
|
||||
folder_id=parent_id,
|
||||
page_size=100,
|
||||
)
|
||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
|
||||
if error:
|
||||
logger.error(f"Failed to list Composio Drive files: {error}")
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Marked Composio connector {connector_id} as auth_expired"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
# Transform files to match the expected format with isFolder field
|
||||
items = []
|
||||
for file_info in files:
|
||||
file_id = file_info.get("id", "") or file_info.get("fileId", "")
|
||||
file_name = (
|
||||
file_info.get("name", "") or file_info.get("fileName", "") or "Untitled"
|
||||
)
|
||||
mime_type = file_info.get("mimeType", "") or file_info.get("mime_type", "")
|
||||
|
||||
if not file_id:
|
||||
continue
|
||||
|
||||
is_folder = mime_type == "application/vnd.google-apps.folder"
|
||||
|
||||
items.append(
|
||||
{
|
||||
"id": file_id,
|
||||
"name": file_name,
|
||||
"mimeType": mime_type,
|
||||
"isFolder": is_folder,
|
||||
"parents": file_info.get("parents", []),
|
||||
"size": file_info.get("size"),
|
||||
"iconLink": file_info.get("iconLink"),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort: folders first, then files, both alphabetically
|
||||
folders = sorted(
|
||||
[item for item in items if item["isFolder"]],
|
||||
key=lambda x: x["name"].lower(),
|
||||
)
|
||||
files_list = sorted(
|
||||
[item for item in items if not item["isFolder"]],
|
||||
key=lambda x: x["name"].lower(),
|
||||
)
|
||||
items = folders + files_list
|
||||
|
||||
folder_count = len(folders)
|
||||
file_count = len(files_list)
|
||||
folder_count = sum(1 for item in items if item.get("isFolder", False))
|
||||
file_count = len(items) - folder_count
|
||||
|
||||
logger.info(
|
||||
f"Listed {len(items)} total items ({folder_count} folders, {file_count} files) for Composio connector {connector_id}"
|
||||
|
|
@ -553,6 +738,31 @@ async def list_composio_drive_folders(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing Composio Drive contents: {e!s}", exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
or "401" in str(e)
|
||||
):
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Marked Composio connector {connector_id} as auth_expired"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list Drive contents: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ SCOPES = [
|
|||
"read:space:confluence",
|
||||
"read:page:confluence",
|
||||
"read:comment:confluence",
|
||||
"write:page:confluence", # Required for creating/updating pages
|
||||
"delete:page:confluence", # Required for deleting pages
|
||||
"offline_access", # Required for refresh tokens
|
||||
]
|
||||
|
||||
|
|
@ -170,7 +172,7 @@ async def confluence_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=confluence_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=confluence_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -196,6 +198,8 @@ async def confluence_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.CONFLUENCE_REDIRECT_URI:
|
||||
|
|
@ -292,6 +296,46 @@ async def confluence_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Handle re-authentication: update existing connector instead of creating new one
|
||||
if reauth_connector_id:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
result = await session.execute(
|
||||
sa_select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Confluence connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=confluence-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=confluence-connector"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR, connector_config
|
||||
|
|
@ -310,7 +354,7 @@ async def confluence_callback(
|
|||
f"Duplicate Confluence connector detected for user {user_id} with instance {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=confluence-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=confluence-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -341,7 +385,7 @@ async def confluence_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=confluence-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=confluence-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -372,6 +416,73 @@ async def confluence_callback(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/confluence/connector/reauth")
|
||||
async def reauth_confluence(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Confluence connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": config.ATLASSIAN_CLIENT_ID,
|
||||
"scope": " ".join(SCOPES),
|
||||
"redirect_uri": config.CONFLUENCE_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Confluence re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Confluence re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Confluence re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_confluence_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ async def discord_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=discord_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=discord_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -311,7 +311,7 @@ async def discord_callback(
|
|||
f"Duplicate Discord connector detected for user {user_id} with server {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=discord-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=discord-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -342,7 +342,7 @@ async def discord_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=discord-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=discord-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ async def create_documents_file_upload(
|
|||
Upload files as documents with real-time status tracking.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately via ElectricSQL)
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately via Zero)
|
||||
- Phase 2: Celery processes each file: pending → processing → ready/failed
|
||||
|
||||
Requires DOCUMENTS_CREATE permission.
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from fastapi.responses import RedirectResponse
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
|
|
@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/calendar.events"]
|
||||
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI
|
||||
|
||||
# Initialize security utilities
|
||||
|
|
@ -111,6 +113,66 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/calendar/connector/reauth")
|
||||
async def reauth_calendar(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Google Calendar re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Google Calendar connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Google Calendar re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Calendar re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Calendar re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/calendar/connector/callback")
|
||||
async def calendar_callback(
|
||||
request: Request,
|
||||
|
|
@ -137,7 +199,7 @@ async def calendar_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_calendar_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -197,6 +259,42 @@ async def calendar_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {**creds_dict}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Calendar connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -210,7 +308,7 @@ async def calendar_callback(
|
|||
f"Duplicate Google Calendar connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-calendar-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-calendar-connector"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -236,7 +334,7 @@ async def calendar_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-calendar-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ async def drive_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_drive_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -345,6 +345,7 @@ async def drive_callback(
|
|||
db_connector.config = {
|
||||
**creds_dict,
|
||||
"start_page_token": existing_start_page_token,
|
||||
"auth_expired": False,
|
||||
}
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
|
|
@ -360,7 +361,7 @@ async def drive_callback(
|
|||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
|
|
@ -375,7 +376,7 @@ async def drive_callback(
|
|||
f"Duplicate Google Drive connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-drive-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-drive-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -425,7 +426,7 @@ async def drive_callback(
|
|||
)
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -502,11 +503,35 @@ async def list_google_drive_folders(
|
|||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
if (
|
||||
"401" in error
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(f"Marked connector {connector_id} as auth_expired")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list folder contents: {error}"
|
||||
)
|
||||
|
||||
# Count folders and files for better logging
|
||||
folder_count = sum(1 for item in items if item.get("isFolder", False))
|
||||
file_count = len(items) - folder_count
|
||||
|
||||
|
|
@ -515,7 +540,6 @@ async def list_google_drive_folders(
|
|||
+ (f" in folder {parent_id}" if parent_id else " in ROOT")
|
||||
)
|
||||
|
||||
# Log first few items for debugging
|
||||
if items:
|
||||
logger.info(f"First 3 items: {[item.get('name') for item in items[:3]]}")
|
||||
|
||||
|
|
@ -525,6 +549,31 @@ async def list_google_drive_folders(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing Drive contents: {e!s}", exc_info=True)
|
||||
error_lower = str(e).lower()
|
||||
if (
|
||||
"401" in str(e)
|
||||
or "invalid_grant" in error_lower
|
||||
or "token has been expired or revoked" in error_lower
|
||||
or "invalid credentials" in error_lower
|
||||
or "authentication failed" in error_lower
|
||||
):
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
try:
|
||||
if connector and not connector.config.get("auth_expired"):
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
logger.info(f"Marked connector {connector_id} as auth_expired")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired for connector {connector_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google Drive authentication expired. Please re-authenticate.",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list Drive contents: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from fastapi.responses import RedirectResponse
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
|
|
@ -71,7 +73,7 @@ def get_google_flow():
|
|||
}
|
||||
},
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
|
|
@ -129,6 +131,66 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/reauth")
|
||||
async def reauth_gmail(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Gmail re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Gmail connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Gmail re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Gmail re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Gmail re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/callback")
|
||||
async def gmail_callback(
|
||||
request: Request,
|
||||
|
|
@ -168,7 +230,7 @@ async def gmail_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=google_gmail_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -228,6 +290,42 @@ async def gmail_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {**creds_dict}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Gmail connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -241,7 +339,7 @@ async def gmail_callback(
|
|||
f"Duplicate Gmail connector detected for user {user_id} with email {user_email}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=google-gmail-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=google-gmail-connector"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -272,7 +370,7 @@ async def gmail_callback(
|
|||
# Redirect to the frontend with success params for indexing config
|
||||
# Using query params to auto-open the popup with config view on new-chat page
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=google-gmail-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ ACCESSIBLE_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-res
|
|||
SCOPES = [
|
||||
"read:jira-work",
|
||||
"read:jira-user",
|
||||
"write:jira-work", # Required for creating/updating/deleting issues
|
||||
"offline_access", # Required for refresh tokens
|
||||
]
|
||||
|
||||
|
|
@ -167,7 +168,7 @@ async def jira_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=jira_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=jira_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -193,6 +194,8 @@ async def jira_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.JIRA_REDIRECT_URI:
|
||||
|
|
@ -310,6 +313,46 @@ async def jira_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Handle re-authentication: update existing connector instead of creating new one
|
||||
if reauth_connector_id:
|
||||
from sqlalchemy.future import select as sa_select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
result = await session.execute(
|
||||
sa_select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Jira connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}?reauth=success&connector=jira-connector"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?reauth=success&connector=jira-connector"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR, connector_config
|
||||
|
|
@ -328,7 +371,7 @@ async def jira_callback(
|
|||
f"Duplicate Jira connector detected for user {user_id} with instance {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=jira-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=jira-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -359,7 +402,7 @@ async def jira_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=jira-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=jira-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -390,6 +433,73 @@ async def jira_callback(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/jira/connector/reauth")
|
||||
async def reauth_jira(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Jira connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": config.ATLASSIAN_CLIENT_ID,
|
||||
"scope": " ".join(SCOPES),
|
||||
"redirect_uri": config.JIRA_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Jira re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Jira re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Jira re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_jira_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
|
|||
|
|
@ -12,8 +12,10 @@ import httpx
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import fetch_linear_organization_name
|
||||
|
|
@ -127,6 +129,70 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/linear/connector/reauth")
|
||||
async def reauth_linear(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Linear re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Linear connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.LINEAR_CLIENT_ID:
|
||||
raise HTTPException(status_code=500, detail="Linear OAuth not configured.")
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.LINEAR_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": config.LINEAR_REDIRECT_URI,
|
||||
"scope": " ".join(SCOPES),
|
||||
"state": state_encoded,
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Linear re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Linear re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Linear re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/linear/connector/callback")
|
||||
async def linear_callback(
|
||||
request: Request,
|
||||
|
|
@ -166,7 +232,7 @@ async def linear_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=linear_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -267,6 +333,43 @@ async def linear_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
connector_config["organization_name"] = org_name
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Linear connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Check for duplicate connector (same organization already connected)
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
|
|
@ -280,7 +383,7 @@ async def linear_callback(
|
|||
f"Duplicate Linear connector detected for user {user_id} with org {org_name}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=linear-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=linear-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -292,6 +395,7 @@ async def linear_callback(
|
|||
org_name,
|
||||
)
|
||||
# Create new connector
|
||||
connector_config["organization_name"] = org_name
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
|
|
@ -311,7 +415,7 @@ async def linear_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=linear-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -342,6 +446,22 @@ async def linear_callback(
|
|||
) from e
|
||||
|
||||
|
||||
async def _mark_connector_auth_expired(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> None:
|
||||
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired flag for connector {connector.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_linear_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
@ -375,6 +495,7 @@ async def refresh_linear_token(
|
|||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
|
|
@ -417,6 +538,7 @@ async def refresh_linear_token(
|
|||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Linear authentication failed. Please re-authenticate.",
|
||||
|
|
@ -453,10 +575,16 @@ async def refresh_linear_token(
|
|||
credentials.expires_at = expires_at
|
||||
credentials.scope = token_json.get("scope")
|
||||
|
||||
# Update connector config with encrypted tokens
|
||||
# Update connector config with encrypted tokens, preserving non-credential fields
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
if connector.config.get("organization_name"):
|
||||
credentials_dict["organization_name"] = connector.config[
|
||||
"organization_name"
|
||||
]
|
||||
credentials_dict.pop("auth_expired", None)
|
||||
connector.config = credentials_dict
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Notifications API routes.
|
||||
These endpoints allow marking notifications as read and fetching older notifications.
|
||||
Electric SQL automatically syncs the changes to all connected clients for recent items.
|
||||
Zero automatically syncs the changes to all connected clients for recent items.
|
||||
For older items (beyond the sync window), use the list endpoint.
|
||||
"""
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ async def get_unread_count(
|
|||
|
||||
This allows the frontend to calculate:
|
||||
- older_unread = total_unread - recent_unread (static until reconciliation)
|
||||
- Display count = older_unread + live_recent_count (from Electric SQL)
|
||||
- Display count = older_unread + live_recent_count (from Zero)
|
||||
"""
|
||||
# Calculate cutoff date for sync window
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
|
@ -344,7 +344,7 @@ async def list_notifications(
|
|||
List notifications for the current user with pagination.
|
||||
|
||||
This endpoint is used as a fallback for older notifications that are
|
||||
outside the Electric SQL sync window (2 weeks).
|
||||
outside the Zero sync window (2 weeks).
|
||||
|
||||
Use `before_date` to paginate through older notifications efficiently.
|
||||
"""
|
||||
|
|
@ -487,7 +487,7 @@ async def mark_notification_as_read(
|
|||
"""
|
||||
Mark a single notification as read.
|
||||
|
||||
Electric SQL will automatically sync this change to all connected clients.
|
||||
Zero will automatically sync this change to all connected clients.
|
||||
"""
|
||||
# Verify the notification belongs to the user
|
||||
result = await session.execute(
|
||||
|
|
@ -528,7 +528,7 @@ async def mark_all_notifications_as_read(
|
|||
"""
|
||||
Mark all notifications as read for the current user.
|
||||
|
||||
Electric SQL will automatically sync these changes to all connected clients.
|
||||
Zero will automatically sync these changes to all connected clients.
|
||||
"""
|
||||
# Update all unread notifications for the user
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -12,8 +12,10 @@ import httpx
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
|
|
@ -124,6 +126,70 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/notion/connector/reauth")
|
||||
async def reauth_notion(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Notion re-authentication for an existing connector."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Notion connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.NOTION_CLIENT_ID:
|
||||
raise HTTPException(status_code=500, detail="Notion OAuth not configured.")
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_params = {
|
||||
"client_id": config.NOTION_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"owner": "user",
|
||||
"redirect_uri": config.NOTION_REDIRECT_URI,
|
||||
"state": state_encoded,
|
||||
}
|
||||
auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
f"Initiating Notion re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Notion re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Notion re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/notion/connector/callback")
|
||||
async def notion_callback(
|
||||
request: Request,
|
||||
|
|
@ -163,7 +229,7 @@ async def notion_callback(
|
|||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=notion_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
|
|
@ -266,6 +332,42 @@ async def notion_callback(
|
|||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Notion connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
# Extract unique identifier from connector credentials
|
||||
connector_identifier = extract_identifier_from_credentials(
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR, connector_config
|
||||
|
|
@ -284,7 +386,7 @@ async def notion_callback(
|
|||
f"Duplicate Notion connector detected for user {user_id} with workspace {connector_identifier}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=duplicate_account&connector=notion-connector"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=notion-connector"
|
||||
)
|
||||
|
||||
# Generate a unique, user-friendly connector name
|
||||
|
|
@ -315,7 +417,7 @@ async def notion_callback(
|
|||
|
||||
# Redirect to the frontend with success params
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector&connectorId={new_connector.id}"
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=notion-connector&connectorId={new_connector.id}"
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
|
@ -346,6 +448,22 @@ async def notion_callback(
|
|||
) from e
|
||||
|
||||
|
||||
async def _mark_connector_auth_expired(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> None:
|
||||
"""Persist auth_expired flag in the connector config so the frontend can show a re-auth prompt."""
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to persist auth_expired flag for connector {connector.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_notion_token(
|
||||
session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
|
|
@ -379,6 +497,7 @@ async def refresh_notion_token(
|
|||
) from e
|
||||
|
||||
if not refresh_token:
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
|
|
@ -421,6 +540,7 @@ async def refresh_notion_token(
|
|||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
await _mark_connector_auth_expired(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Notion authentication failed. Please re-authenticate.",
|
||||
|
|
@ -469,7 +589,9 @@ async def refresh_notion_token(
|
|||
# Update connector config with encrypted tokens
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
credentials_dict.pop("auth_expired", None)
|
||||
connector.config = credentials_dict
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue