Merge pull request #1030 from MODSetter/dev

feat: desktop quick-ask, parallel indexing, UI/UX fixes & agent rework
This commit is contained in:
Rohan Verma 2026-03-28 17:09:14 -07:00 committed by GitHub
commit c74b51745c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
178 changed files with 18379 additions and 9024 deletions

View file

@ -0,0 +1,441 @@
---
name: python-patterns
description: Python development principles and decision-making. Framework selection, async patterns, type hints, project structure. Teaches thinking, not copying.
allowed-tools: Read, Write, Edit, Glob, Grep
---
# Python Patterns
> Python development principles and decision-making for 2025.
> **Learn to THINK, not memorize patterns.**
---
## ⚠️ How to Use This Skill
This skill teaches **decision-making principles**, not fixed code to copy.
- ASK user for framework preference when unclear
- Choose async vs sync based on CONTEXT
- Don't default to same framework every time
---
## 1. Framework Selection (2025)
### Decision Tree
```
What are you building?
├── API-first / Microservices
│ └── FastAPI (async, modern, fast)
├── Full-stack web / CMS / Admin
│ └── Django (batteries-included)
├── Simple / Script / Learning
│ └── Flask (minimal, flexible)
├── AI/ML API serving
│ └── FastAPI (Pydantic, async, uvicorn)
└── Background workers
└── Celery + any framework
```
### Comparison Principles
| Factor | FastAPI | Django | Flask |
|--------|---------|--------|-------|
| **Best for** | APIs, microservices | Full-stack, CMS | Simple, learning |
| **Async** | Native | Django 5.0+ | Via extensions |
| **Admin** | Manual | Built-in | Via extensions |
| **ORM** | Choose your own | Django ORM | Choose your own |
| **Learning curve** | Low | Medium | Low |
### Selection Questions to Ask:
1. Is this API-only or full-stack?
2. Need admin interface?
3. Team familiar with async?
4. Existing infrastructure?
---
## 2. Async vs Sync Decision
### When to Use Async
```
async def is better when:
├── I/O-bound operations (database, HTTP, file)
├── Many concurrent connections
├── Real-time features
├── Microservices communication
└── FastAPI/Starlette/Django ASGI
def (sync) is better when:
├── CPU-bound operations
├── Simple scripts
├── Legacy codebase
├── Team unfamiliar with async
└── Blocking libraries (no async version)
```
### The Golden Rule
```
I/O-bound → async (waiting for external)
CPU-bound → sync + multiprocessing (computing)
Don't:
├── Mix sync and async carelessly
├── Use sync libraries in async code
└── Force async for CPU work
```
### Async Library Selection
| Need | Async Library |
|------|---------------|
| HTTP client | httpx |
| PostgreSQL | asyncpg |
| Redis | aioredis / redis-py async |
| File I/O | aiofiles |
| Database ORM | SQLAlchemy 2.0 async, Tortoise |
---
## 3. Type Hints Strategy
### When to Type
```
Always type:
├── Function parameters
├── Return types
├── Class attributes
├── Public APIs
Can skip:
├── Local variables (let inference work)
├── One-off scripts
├── Tests (usually)
```
### Common Type Patterns
```python
# These are patterns, understand them:
# Optional → might be None
from typing import Optional
def find_user(id: int) -> Optional[User]: ...
# Union → one of multiple types
def process(data: str | dict) -> None: ...
# Generic collections
def get_items() -> list[Item]: ...
def get_mapping() -> dict[str, int]: ...
# Callable
from typing import Callable
def apply(fn: Callable[[int], str]) -> str: ...
```
### Pydantic for Validation
```
When to use Pydantic:
├── API request/response models
├── Configuration/settings
├── Data validation
├── Serialization
Benefits:
├── Runtime validation
├── Auto-generated JSON schema
├── Works with FastAPI natively
└── Clear error messages
```
---
## 4. Project Structure Principles
### Structure Selection
```
Small project / Script:
├── main.py
├── utils.py
└── requirements.txt
Medium API:
├── app/
│ ├── __init__.py
│ ├── main.py
│ ├── models/
│ ├── routes/
│ ├── services/
│ └── schemas/
├── tests/
└── pyproject.toml
Large application:
├── src/
│ └── myapp/
│ ├── core/
│ ├── api/
│ ├── services/
│ ├── models/
│ └── ...
├── tests/
└── pyproject.toml
```
### FastAPI Structure Principles
```
Organize by feature or layer:
By layer:
├── routes/ (API endpoints)
├── services/ (business logic)
├── models/ (database models)
├── schemas/ (Pydantic models)
└── dependencies/ (shared deps)
By feature:
├── users/
│ ├── routes.py
│ ├── service.py
│ └── schemas.py
└── products/
└── ...
```
---
## 5. Django Principles (2025)
### Django Async (Django 5.0+)
```
Django supports async:
├── Async views
├── Async middleware
├── Async ORM (limited)
└── ASGI deployment
When to use async in Django:
├── External API calls
├── WebSocket (Channels)
├── High-concurrency views
└── Background task triggering
```
### Django Best Practices
```
Model design:
├── Fat models, thin views
├── Use managers for common queries
├── Abstract base classes for shared fields
Views:
├── Class-based for complex CRUD
├── Function-based for simple endpoints
├── Use viewsets with DRF
Queries:
├── select_related() for FKs
├── prefetch_related() for M2M
├── Avoid N+1 queries
└── Use .only() for specific fields
```
---
## 6. FastAPI Principles
### async def vs def in FastAPI
```
Use async def when:
├── Using async database drivers
├── Making async HTTP calls
├── I/O-bound operations
└── Want to handle concurrency
Use def when:
├── Blocking operations
├── Sync database drivers
├── CPU-bound work
└── FastAPI runs in threadpool automatically
```
### Dependency Injection
```
Use dependencies for:
├── Database sessions
├── Current user / Auth
├── Configuration
├── Shared resources
Benefits:
├── Testability (mock dependencies)
├── Clean separation
├── Automatic cleanup (yield)
```
### Pydantic v2 Integration
```python
# FastAPI + Pydantic are tightly integrated:
# Request validation
@app.post("/users")
async def create(user: UserCreate) -> UserResponse:
# user is already validated
...
# Response serialization
# Return type becomes response schema
```
---
## 7. Background Tasks
### Selection Guide
| Solution | Best For |
|----------|----------|
| **BackgroundTasks** | Simple, in-process tasks |
| **Celery** | Distributed, complex workflows |
| **ARQ** | Async, Redis-based |
| **RQ** | Simple Redis queue |
| **Dramatiq** | Actor-based, simpler than Celery |
### When to Use Each
```
FastAPI BackgroundTasks:
├── Quick operations
├── No persistence needed
├── Fire-and-forget
└── Same process
Celery/ARQ:
├── Long-running tasks
├── Need retry logic
├── Distributed workers
├── Persistent queue
└── Complex workflows
```
---
## 8. Error Handling Principles
### Exception Strategy
```
In FastAPI:
├── Create custom exception classes
├── Register exception handlers
├── Return consistent error format
└── Log without exposing internals
Pattern:
├── Raise domain exceptions in services
├── Catch and transform in handlers
└── Client gets clean error response
```
### Error Response Philosophy
```
Include:
├── Error code (programmatic)
├── Message (human readable)
├── Details (field-level when applicable)
└── NOT stack traces (security)
```
---
## 9. Testing Principles
### Testing Strategy
| Type | Purpose | Tools |
|------|---------|-------|
| **Unit** | Business logic | pytest |
| **Integration** | API endpoints | pytest + httpx/TestClient |
| **E2E** | Full workflows | pytest + DB |
### Async Testing
```python
# Use pytest-asyncio for async tests
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_endpoint():
async with AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/users")
assert response.status_code == 200
```
### Fixtures Strategy
```
Common fixtures:
├── db_session → Database connection
├── client → Test client
├── authenticated_user → User with token
└── sample_data → Test data setup
```
---
## 10. Decision Checklist
Before implementing:
- [ ] **Asked user about framework preference?**
- [ ] **Chosen framework for THIS context?** (not just default)
- [ ] **Decided async vs sync?**
- [ ] **Planned type hint strategy?**
- [ ] **Defined project structure?**
- [ ] **Planned error handling?**
- [ ] **Considered background tasks?**
---
## 11. Anti-Patterns to Avoid
### ❌ DON'T:
- Default to Django for simple APIs (FastAPI may be better)
- Use sync libraries in async code
- Skip type hints for public APIs
- Put business logic in routes/views
- Ignore N+1 queries
- Mix async and sync carelessly
### ✅ DO:
- Choose framework based on context
- Ask about async requirements
- Use Pydantic for validation
- Separate concerns (routes → services → repos)
- Test critical paths
---
> **Remember**: Python patterns are about decision-making for YOUR specific context. Don't copy code—think about what serves your application best.

View file

@ -0,0 +1,26 @@
{
"name": "python-patterns",
"description": "Python development principles and decision-making. Framework selection, async patterns, type hints, project structure. Teaches thinking, not copying.",
"category": "development",
"canonical_category": "development",
"repository": "majiayu000/claude-skill-registry-data",
"repository_url": "https://github.com/majiayu000/claude-skill-registry-data",
"author": "majiayu000",
"author_avatar": "https://github.com/majiayu000.png",
"file_path": "data/antigravity-python-patterns/SKILL.md",
"source": "github_curated_repos",
"stars": 2,
"quality_score": 64,
"best_practices_score": 60,
"skill_level": 3,
"skill_level_label": "resources",
"has_scripts": true,
"has_extra_files": false,
"downloads": 0,
"content_hash": "4dd862e5d25189c938320851ba3a0fcc6155f88b67f3822aaa4d0fac7e2d4659",
"indexed_at": "2026-03-01T03:09:48.327Z",
"synced_at": "2026-03-01T06:21:10.361Z",
"omni_registry_url": "https://omni-skill-registry.omniroute.online/#/skill/1d2fb797167c0af9a8694048782c1c3aab4b64b5b33cd055e1e7b702ede3eeb5",
"install_command": "mkdir -p .claude/skills/python-patterns && curl -sL \"https://raw.githubusercontent.com/majiayu000/claude-skill-registry-data/main/data/antigravity-python-patterns/SKILL.md\" > .claude/skills/python-patterns/SKILL.md",
"raw_url": "https://raw.githubusercontent.com/majiayu000/claude-skill-registry-data/main/data/antigravity-python-patterns/SKILL.md"
}

3
.gitignore vendored
View file

@ -5,4 +5,5 @@ node_modules/
.ruff_cache/
.venv
.pnpm-store
.DS_Store
.DS_Store
deepagents/

View file

@ -169,13 +169,3 @@ LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense
# Agent Specific Configuration
# Daytona Sandbox (secure cloud code execution for deep agent)
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
DAYTONA_SANDBOX_ENABLED=TRUE
DAYTONA_API_KEY=dtn_asdasfasfafas
DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files

View file

@ -37,7 +37,9 @@ def upgrade() -> None:
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT 1 FROM information_schema.tables WHERE table_name = 'video_presentations'")
sa.text(
"SELECT 1 FROM information_schema.tables WHERE table_name = 'video_presentations'"
)
)
if not result.fetchone():
op.create_table(

View file

@ -0,0 +1,90 @@
"""Add folders table and folder_id to documents
Revision ID: 109
Revises: 108
Creates the folders table for nested folder organization (max 8 levels),
adds folder_id FK to documents, and creates an expression-based unique
index to correctly handle NULL parent_id at root level.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "109"
down_revision: str | None = "108"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.create_table(
"folders",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(255), nullable=False, index=True),
sa.Column("position", sa.String(50), nullable=False, index=True),
sa.Column(
"parent_id",
sa.Integer(),
sa.ForeignKey("folders.id", ondelete="CASCADE"),
nullable=True,
index=True,
),
sa.Column(
"search_space_id",
sa.Integer(),
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column(
"created_by_id",
sa.Uuid(),
sa.ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
# Expression-based unique index: COALESCE(parent_id, 0) handles NULL correctly.
# PostgreSQL treats NULL != NULL in regular unique constraints, so a standard
# UniqueConstraint(search_space_id, parent_id, name) would allow duplicate
# folder names at the root level.
op.execute(
"""
CREATE UNIQUE INDEX uq_folder_space_parent_name
ON folders (search_space_id, COALESCE(parent_id, 0), name);
"""
)
op.add_column(
"documents",
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("folders.id", ondelete="SET NULL"),
nullable=True,
index=True,
),
)
def downgrade() -> None:
op.drop_column("documents", "folder_id")
op.execute("DROP INDEX IF EXISTS uq_folder_space_parent_name;")
op.drop_table("folders")

View file

@ -1,11 +1,12 @@
"""
SurfSense New Chat Agent Module.
This module provides the SurfSense deep agent with configurable tools
for knowledge base search, podcast generation, and more.
This module provides the SurfSense deep agent with configurable tools,
middleware, and preloaded knowledge-base filesystem behavior.
Directory Structure:
- tools/: All agent tools (knowledge_base, podcast, generate_image, etc.)
- tools/: All agent tools (podcast, generate_image, web, memory, etc.)
- middleware/: Custom middleware (knowledge search, filesystem, dedup, etc.)
- chat_deepagent.py: Main agent factory
- system_prompt.py: System prompts and instructions
- context.py: Context schema for the agent
@ -23,6 +24,13 @@ from .context import SurfSenseContextSchema
# LLM config
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
# Middleware
from .middleware import (
DedupHITLToolCallsMiddleware,
KnowledgeBaseSearchMiddleware,
SurfSenseFilesystemMiddleware,
)
# System prompt
from .system_prompt import (
SURFSENSE_CITATION_INSTRUCTIONS,
@ -39,7 +47,6 @@ from .tools import (
build_tools,
create_generate_podcast_tool,
create_scrape_webpage_tool,
create_search_knowledge_base_tool,
format_documents_for_context,
get_all_tool_names,
get_default_enabled_tools,
@ -53,8 +60,12 @@ __all__ = [
# System prompt
"SURFSENSE_CITATION_INSTRUCTIONS",
"SURFSENSE_SYSTEM_PROMPT",
# Middleware
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
# Context
"SurfSenseContextSchema",
"SurfSenseFilesystemMiddleware",
"ToolDefinition",
"build_surfsense_system_prompt",
"build_tools",
@ -63,7 +74,6 @@ __all__ = [
# Tool factories
"create_generate_podcast_tool",
"create_scrape_webpage_tool",
"create_search_knowledge_base_tool",
# Agent factory
"create_surfsense_deep_agent",
# Knowledge base utilities

View file

@ -4,6 +4,13 @@ SurfSense deep agent implementation.
This module provides the factory function for creating SurfSense deep agents
with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
(from deepagents) so that the middleware stack is fully under our control.
This lets us swap in ``SurfSenseFilesystemMiddleware`` a customisable
subclass of the default ``FilesystemMiddleware`` while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
summarisation, prompt-caching, etc.).
"""
import asyncio
@ -12,8 +19,15 @@ import time
from collections.abc import Sequence
from typing import Any
from deepagents import create_deep_agent
from deepagents.backends.protocol import SandboxBackendProtocol
from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_version
from deepagents.backends import StateBackend
from deepagents.graph import BASE_AGENT_PROMPT
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from deepagents.middleware.summarization import create_summarization_middleware
from langchain.agents import create_agent
from langchain.agents.middleware import TodoListMiddleware
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -21,8 +35,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware.dedup_tool_calls import (
from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
KnowledgeBaseSearchMiddleware,
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
@ -40,15 +56,15 @@ _perf_log = get_perf_logger()
# =============================================================================
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
# used by the knowledge_base and web_search tools.
# used by pre-search middleware and web_search.
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
# the web_search tool; all others go to search_knowledge_base.
# the web_search tool; all others are considered local/indexed data.
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
# Live search connectors (handled by web_search tool)
"TAVILY_API": "TAVILY_API",
"LINKUP_API": "LINKUP_API",
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
# Local/indexed connectors (handled by search_knowledge_base tool)
# Local/indexed connectors (handled by KB pre-search middleware)
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
@ -141,13 +157,11 @@ async def create_surfsense_deep_agent(
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_backend: SandboxBackendProtocol | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
The agent comes with built-in tools that can be configured:
- search_knowledge_base: Search the user's personal knowledge base
- generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models
- scrape_webpage: Extract content from webpages
@ -179,9 +193,6 @@ async def create_surfsense_deep_agent(
These are always added regardless of enabled/disabled settings.
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
Falls back to Chromium/Trafilatura if not provided.
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
secure code execution. When provided, the agent gets an
isolated ``execute`` tool for running shell commands.
Returns:
CompiledStateGraph: The configured deep agent
@ -205,7 +216,7 @@ async def create_surfsense_deep_agent(
# Create agent with only specific tools
agent = create_surfsense_deep_agent(
llm, search_space_id, db_session, ...,
enabled_tools=["search_knowledge_base", "scrape_webpage"]
enabled_tools=["scrape_webpage"]
)
# Create agent without podcast generation
@ -357,6 +368,10 @@ async def create_surfsense_deep_agent(
]
modified_disabled_tools.extend(confluence_tools)
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
if "search_knowledge_base" not in modified_disabled_tools:
modified_disabled_tools.append("search_knowledge_base")
# Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter()
tools = await build_tools_async(
@ -373,7 +388,6 @@ async def create_surfsense_deep_agent(
# Build system prompt based on agent_config, scoped to the tools actually enabled
_t0 = time.perf_counter()
_sandbox_enabled = sandbox_backend is not None
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
if agent_config is not None:
@ -382,14 +396,12 @@ async def create_surfsense_deep_agent(
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
)
@ -397,24 +409,69 @@ async def create_surfsense_deep_agent(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
# Build optional kwargs for the deep agent
deep_agent_kwargs: dict[str, Any] = {}
if sandbox_backend is not None:
deep_agent_kwargs["backend"] = sandbox_backend
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
# General-purpose subagent middleware
gp_middleware = [
TodoListMiddleware(),
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
),
create_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": gp_middleware,
}
# Main agent middleware
deepagent_middleware = [
TodoListMiddleware(),
KnowledgeBaseSearchMiddleware(
search_space_id=search_space_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
),
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
create_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
create_deep_agent,
model=llm,
create_agent,
llm,
system_prompt=final_system_prompt,
tools=tools,
system_prompt=system_prompt,
middleware=deepagent_middleware,
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer,
middleware=[DedupHITLToolCallsMiddleware()],
**deep_agent_kwargs,
)
agent = agent.with_config(
{
"recursion_limit": 10_000,
"metadata": {
"ls_integration": "deepagents",
"versions": {"deepagents": deepagents_version},
},
}
)
_perf_log.info(
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
"[create_agent] Graph compiled (create_agent) in %.3fs",
time.perf_counter() - _t0,
)

View file

@ -0,0 +1,17 @@
"""Middleware components for the SurfSense new chat agent."""
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
__all__ = [
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
"SurfSenseFilesystemMiddleware",
]

View file

@ -0,0 +1,694 @@
"""Custom filesystem middleware for the SurfSense agent.
This middleware customizes prompts and persists write/edit operations for
`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
"""
from __future__ import annotations
import asyncio
import re
from datetime import UTC, datetime
from typing import Annotated, Any
from deepagents import FilesystemMiddleware
from deepagents.backends.protocol import EditResult, WriteResult
from deepagents.backends.utils import validate_path
from deepagents.middleware.filesystem import FilesystemState
from fractional_indexing import generate_key_between
from langchain.tools import ToolRuntime
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import delete, select
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
from app.indexing_pipeline.document_chunker import chunk_text
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
# =============================================================================
# System Prompt (injected into every model call by wrap_model_call)
# =============================================================================
SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
- Read files before editing understand existing content before making changes.
- Mimic existing style, naming conventions, and patterns.
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`, `save_document`
All file paths must start with a `/`.
- ls: list files and directories at a given path.
- read_file: read a file from the filesystem.
- write_file: create a temporary file in the session (not persisted).
- edit_file: edit a file in the session (not persisted for /documents/ files).
- glob: find files matching a pattern (e.g., "**/*.xml").
- grep: search for text within files.
- save_document: **permanently** save a new document to the user's knowledge
base. Use only when the user explicitly asks to save/create a document.
## Reading Documents Efficiently
Documents are formatted as XML. Each document contains:
- `<document_metadata>` title, type, URL, etc.
- `<chunk_index>` a table of every chunk with its **line range** and a
`matched="true"` flag for chunks that matched the search query.
- `<document_content>` the actual chunks in original document order.
**Workflow**: when reading a large document, read the first ~20 lines to see
the `<chunk_index>`, identify chunks marked `matched="true"`, then use
`read_file(path, offset=<start_line>, limit=<lines>)` to jump directly to
those sections instead of reading the entire file sequentially.
Use `<chunk id='...'>` values as citation IDs in your answers.
"""
# =============================================================================
# Per-Tool Descriptions (shown to the LLM as the tool's docstring)
# =============================================================================
SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
"""
SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.
Usage:
- By default, reads up to 100 lines from the beginning.
- Use `offset` and `limit` for pagination when files are large.
- Results include line numbers.
- Documents contain a `<chunk_index>` near the top listing every chunk with
its line range and a `matched="true"` flag for search-relevant chunks.
Read the index first, then jump to matched chunks with
`read_file(path, offset=<start_line>, limit=<num_lines>)`.
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.
To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
"""
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
IMPORTANT:
- Read the file before editing.
- Preserve exact indentation and formatting.
- Edits to documents under `/documents/` are session-only (not persisted to the
database) because those files use an XML citation wrapper around the original
content.
"""
SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
Supports standard glob patterns: `*`, `**`, `?`.
Returns absolute file paths.
"""
SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
Use this to locate relevant document files/chunks before reading full files.
"""
SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base.
This is an expensive operation it creates a new Document record in the
database, chunks the content, and generates embeddings for search.
Use ONLY when the user explicitly asks to save/create/store a document.
Do NOT use this for scratch work; use `write_file` for temporary files.
Args:
title: The document title (e.g., "Meeting Notes 2025-06-01").
content: The plain-text or markdown content to save. Do NOT include XML
citation wrappers pass only the actual document text.
folder_path: Optional folder path under /documents/ (e.g., "Work/Notes").
Folders are created automatically if they don't exist.
"""
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware with DB persistence for docs."""
def __init__(
self,
*,
search_space_id: int | None = None,
created_by_id: str | None = None,
tool_token_limit_before_evict: int | None = 20000,
) -> None:
self._search_space_id = search_space_id
self._created_by_id = created_by_id
super().__init__(
system_prompt=SURFSENSE_FILESYSTEM_SYSTEM_PROMPT,
custom_tool_descriptions={
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
"read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
"write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
"edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
"glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
"grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
},
tool_token_limit_before_evict=tool_token_limit_before_evict,
)
# Remove the execute tool (no sandbox backend)
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(self._create_save_document_tool())
@staticmethod
def _run_async_blocking(coro: Any) -> Any:
"""Run async coroutine from sync code path when no event loop is running."""
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return "Error: sync filesystem persistence not supported inside an active event loop."
except RuntimeError:
pass
return asyncio.run(coro)
@staticmethod
def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
"""Parse /documents/... path into folder parts and a document title."""
if not file_path.startswith("/documents/"):
return [], ""
rel = file_path[len("/documents/") :].strip("/")
if not rel:
return [], ""
parts = [part for part in rel.split("/") if part]
file_name = parts[-1]
title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
return parts[:-1], title
async def _ensure_folder_hierarchy(
self,
*,
folder_parts: list[str],
search_space_id: int,
) -> int | None:
"""Ensure folder hierarchy exists and return leaf folder ID."""
if not folder_parts:
return None
async with shielded_async_session() as session:
parent_id: int | None = None
for name in folder_parts:
result = await session.execute(
select(Folder).where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
Folder.name == name,
)
)
folder = result.scalar_one_or_none()
if folder is None:
sibling_result = await session.execute(
select(Folder.position)
.where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
)
.order_by(Folder.position.desc())
.limit(1)
)
last_position = sibling_result.scalar_one_or_none()
folder = Folder(
name=name,
position=generate_key_between(last_position, None),
parent_id=parent_id,
search_space_id=search_space_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(folder)
await session.flush()
parent_id = folder.id
await session.commit()
return parent_id
async def _persist_new_document(
self, *, file_path: str, content: str
) -> dict[str, Any] | str:
"""Persist a new NOTE document from a newly written file.
Returns a dict with document metadata on success, or an error string.
"""
if self._search_space_id is None:
return {}
folder_parts, title = self._parse_virtual_path(file_path)
if not title:
return "Error: write_file for document persistence requires path under /documents/<name>.xml"
folder_id = await self._ensure_folder_hierarchy(
folder_parts=folder_parts,
search_space_id=self._search_space_id,
)
async with shielded_async_session() as session:
content_hash = generate_content_hash(content, self._search_space_id)
existing = await session.execute(
select(Document.id).where(Document.content_hash == content_hash)
)
if existing.scalar_one_or_none() is not None:
return "Error: A document with identical content already exists."
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc = Document(
title=title,
document_type=DocumentType.NOTE,
document_metadata={"virtual_path": file_path},
content=content,
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=content,
search_space_id=self._search_space_id,
folder_id=folder_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(doc)
await session.flush()
summary_embedding = embed_texts([content])[0]
doc.embedding = summary_embedding
chunk_texts = chunk_text(content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
chunks = [
Chunk(document_id=doc.id, content=text, embedding=embedding)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
session.add_all(chunks)
await session.commit()
return {
"id": doc.id,
"title": title,
"documentType": DocumentType.NOTE.value,
"searchSpaceId": self._search_space_id,
"folderId": folder_id,
"createdById": str(self._created_by_id)
if self._created_by_id
else None,
}
async def _persist_edited_document(
self, *, file_path: str, updated_content: str
) -> str | None:
"""Persist edits for an existing NOTE document and recreate chunks."""
if self._search_space_id is None:
return None
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc_id_from_xml: int | None = None
match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
if match:
doc_id_from_xml = int(match.group(1))
async with shielded_async_session() as session:
doc_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.unique_identifier_hash == unique_identifier_hash,
)
)
document = doc_result.scalar_one_or_none()
if document is None and doc_id_from_xml is not None:
by_id_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.id == doc_id_from_xml,
)
)
document = by_id_result.scalar_one_or_none()
if document is None:
return "Error: Could not map edited file to an existing document."
document.content = updated_content
document.source_markdown = updated_content
document.content_hash = generate_content_hash(
updated_content, self._search_space_id
)
document.updated_at = datetime.now(UTC)
if not document.document_metadata:
document.document_metadata = {}
document.document_metadata["virtual_path"] = file_path
summary_embedding = embed_texts([updated_content])[0]
document.embedding = summary_embedding
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
chunk_texts = chunk_text(updated_content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(
document_id=document.id, content=text, embedding=embedding
)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
)
await session.commit()
return None
def _create_save_document_tool(self) -> BaseTool:
"""Create save_document tool that persists a new document to the KB."""
def sync_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = self._run_async_blocking(
self._persist_new_document(file_path=virtual_path, content=content)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
async def async_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = await self._persist_new_document(
file_path=virtual_path, content=content
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
return StructuredTool.from_function(
name="save_document",
description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION,
func=sync_save_document,
coroutine=async_save_document,
)
def _create_write_file_tool(self) -> BaseTool:
"""Create write_file — ephemeral for /documents/*, persisted otherwise."""
tool_description = (
self._custom_tool_descriptions.get("write_file")
or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION
)
def sync_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = resolved_backend.write(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = self._run_async_blocking(
self._persist_new_document(
file_path=validated_path, content=content
)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
async def async_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = await resolved_backend.awrite(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = await self._persist_new_document(
file_path=validated_path,
content=content,
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
return StructuredTool.from_function(
name="write_file",
description=tool_description,
func=sync_write_file,
coroutine=async_write_file,
)
@staticmethod
def _is_kb_document(path: str) -> bool:
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
return path.startswith("/documents/")
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (
self._custom_tool_descriptions.get("edit_file")
or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION
)
def sync_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = resolved_backend.edit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = resolved_backend.read(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_result = self._run_async_blocking(
self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
)
if isinstance(persist_result, str):
return persist_result
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
async def async_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = await resolved_backend.aedit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = await resolved_backend.aread(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_error = await self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
if persist_error:
return persist_error
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
return StructuredTool.from_function(
name="edit_file",
description=tool_description,
func=sync_edit_file,
coroutine=async_edit_file,
)

View file

@ -0,0 +1,414 @@
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
This middleware runs before the main agent loop and seeds a virtual filesystem
(`files` state) with relevant documents retrieved via hybrid search. On each
turn the filesystem is *expanded* new results merge with documents loaded
during prior turns and a synthetic ``ls`` result is injected into the message
history so the LLM is immediately aware of the current filesystem structure.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content."""
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(p for p in parts if p)
return str(content)
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
"""Convert arbitrary text into a filesystem-safe filename."""
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", " ", name)
if not name:
name = fallback
if len(name) > 180:
name = name[:180].rstrip()
if not name.lower().endswith(".xml"):
name = f"{name}.xml"
return name
def _build_document_xml(
document: dict[str, Any],
matched_chunk_ids: set[int] | None = None,
) -> str:
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
The ``<chunk_index>`` at the top of each document lists every chunk with its
line range inside ``<document_content>`` and flags chunks that directly
matched the search query (``matched="true"``). This lets the LLM jump
straight to the most relevant section via ``read_file(offset=, limit=)``
instead of reading sequentially from the start.
"""
matched = matched_chunk_ids or set()
doc_meta = document.get("document") or {}
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
url = (
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
)
metadata_json = json.dumps(metadata, ensure_ascii=False)
# --- 1. Metadata header (fixed structure) ---
metadata_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{document_id}</document_id>",
f" <document_type>{document_type}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
]
# --- 2. Pre-build chunk XML strings to compute line counts ---
chunks = document.get("chunks") or []
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
if isinstance(chunks, list):
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_id = chunk.get("chunk_id") or chunk.get("id")
chunk_content = str(chunk.get("content", "")).strip()
if not chunk_content:
continue
if chunk_id is None:
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
else:
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
chunk_entries.append((chunk_id, xml))
# --- 3. Compute line numbers for every chunk ---
# Layout (1-indexed lines for read_file):
# metadata_lines -> len(metadata_lines) lines
# <chunk_index> -> 1 line
# index entries -> len(chunk_entries) lines
# </chunk_index> -> 1 line
# (empty line) -> 1 line
# <document_content> -> 1 line
# chunk xml lines…
# </document_content> -> 1 line
# </document> -> 1 line
index_overhead = (
1 + len(chunk_entries) + 1 + 1 + 1
) # tags + empty + <document_content>
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
current_line = first_chunk_line
index_entry_lines: list[str] = []
for cid, xml_str in chunk_entries:
num_lines = xml_str.count("\n") + 1
end_line = current_line + num_lines - 1
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
if cid is not None:
index_entry_lines.append(
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
)
else:
index_entry_lines.append(
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
)
current_line = end_line + 1
# --- 4. Assemble final XML ---
lines = metadata_lines.copy()
lines.append("<chunk_index>")
lines.extend(index_entry_lines)
lines.append("</chunk_index>")
lines.append("")
lines.append("<document_content>")
for _, xml_str in chunk_entries:
lines.append(xml_str)
lines.extend(["</document_content>", "</document>"])
return "\n".join(lines)
async def _get_folder_paths(
session: AsyncSession, search_space_id: int
) -> dict[int, str]:
"""Return a map of folder_id -> virtual folder path under /documents."""
result = await session.execute(
select(Folder.id, Folder.name, Folder.parent_id).where(
Folder.search_space_id == search_space_id
)
)
rows = result.all()
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
cache: dict[int, str] = {}
def resolve_path(folder_id: int) -> str:
if folder_id in cache:
return cache[folder_id]
parts: list[str] = []
cursor: int | None = folder_id
visited: set[int] = set()
while cursor is not None and cursor in by_id and cursor not in visited:
visited.add(cursor)
entry = by_id[cursor]
parts.append(
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
".xml"
)
)
cursor = entry["parent_id"]
parts.reverse()
path = "/documents/" + "/".join(parts) if parts else "/documents"
cache[folder_id] = path
return path
for folder_id in by_id:
resolve_path(folder_id)
return cache
def _build_synthetic_ls(
existing_files: dict[str, Any] | None,
new_files: dict[str, Any],
) -> tuple[AIMessage, ToolMessage]:
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
Paths are listed with *new* (rank-ordered) files first, then existing files
that were already in state from prior turns.
"""
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
doc_paths = [
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
]
new_set = set(new_files)
new_paths = [p for p in doc_paths if p in new_set]
old_paths = [p for p in doc_paths if p not in new_set]
ordered = new_paths + old_paths
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
ai_msg = AIMessage(
content="",
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
)
tool_msg = ToolMessage(
content=str(ordered) if ordered else "No documents found.",
tool_call_id=tool_call_id,
)
return ai_msg, tool_msg
def _resolve_search_types(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> list[str] | None:
"""Build a flat list of document-type strings for the chunk retriever.
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
old documents indexed under Composio names are still found.
Returns ``None`` when no filtering is desired (search all types).
"""
types: set[str] = set()
if available_document_types:
types.update(available_document_types)
if available_connectors:
types.update(available_connectors)
if not types:
return None
expanded: set[str] = set(types)
for t in types:
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
if legacy:
expanded.add(legacy)
return list(expanded) if expanded else None
async def search_knowledge_base(
*,
query: str,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base.
Uses one ``ChucksHybridSearchRetriever`` call across all document types
instead of fanning out per-connector. This reduces the number of DB
queries from ~10 to 2 (one RRF query + one chunk fetch).
"""
if not query:
return []
[embedding] = embed_texts([query])
doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30)
async with shielded_async_session() as session:
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text=query,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=doc_types,
query_embedding=embedding.tolist(),
)
return results[:top_k]
async def build_scoped_filesystem(
*,
documents: Sequence[dict[str, Any]],
search_space_id: int,
) -> dict[str, dict[str, str]]:
"""Build a StateBackend-compatible files dict from search results."""
async with shielded_async_session() as session:
folder_paths = await _get_folder_paths(session, search_space_id)
doc_ids = [
(doc.get("document") or {}).get("id")
for doc in documents
if isinstance(doc, dict)
]
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
folder_by_doc_id: dict[int, int | None] = {}
if doc_ids:
doc_rows = await session.execute(
select(Document.id, Document.folder_id).where(
Document.search_space_id == search_space_id,
Document.id.in_(doc_ids),
)
)
folder_by_doc_id = {
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
}
files: dict[str, dict[str, str]] = {}
for document in documents:
doc_meta = document.get("document") or {}
title = str(doc_meta.get("title") or "untitled")
doc_id = doc_meta.get("id")
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
base_folder = folder_paths.get(folder_id, "/documents")
file_name = _safe_filename(title)
path = f"{base_folder}/{file_name}"
matched_ids = set(document.get("matched_chunk_ids") or [])
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
files[path] = {
"content": xml_content.split("\n"),
"encoding": "utf-8",
"created_at": "",
"modified_at": "",
}
return files
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
tools = ()
def __init__(
self,
*,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> None:
self.search_space_id = search_space_id
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
def before_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return None
except RuntimeError:
pass
return asyncio.run(self.abefore_agent(state, runtime))
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_message = messages[-1]
if not isinstance(last_message, HumanMessage):
return None
user_text = _extract_text_from_message(last_message).strip()
if not user_text:
return None
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
search_results = await search_knowledge_base(
query=user_text,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
)
new_files = await build_scoped_filesystem(
documents=search_results,
search_space_id=self.search_space_id,
)
ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(new_files),
len(new_files) + len(existing_files or {}),
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}

View file

@ -25,6 +25,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
<knowledge_base_only_policy>
CRITICAL RULE KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs.
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission.
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
1. Inform the user that you could not find relevant information in their knowledge base.
2. Ask the user: "Would you like me to answer from my general knowledge instead?"
3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes.
- This policy does NOT apply to:
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
</system_instruction>
"""
@ -41,6 +56,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
<knowledge_base_only_policy>
CRITICAL RULE KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs.
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission.
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
1. Inform the team that you could not find relevant information in the shared knowledge base.
2. Ask: "Would you like me to answer from my general knowledge instead?"
3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes.
- This policy does NOT apply to:
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
</system_instruction>
"""
@ -67,15 +97,6 @@ _TOOLS_PREAMBLE = """
<tools>
You have access to the following tools:
CRITICAL BEHAVIORAL RULE SEARCH FIRST, ANSWER LATER:
For ANY user query that is ambiguous, open-ended, or could potentially have relevant context in the
knowledge base, you MUST call `search_knowledge_base` BEFORE attempting to answer from your own
general knowledge. This includes (but is not limited to) questions about concepts, topics, projects,
people, events, recommendations, or anything the user might have stored notes/documents about.
Only fall back to your own general knowledge if the search returns NO relevant results.
Do NOT skip the search and answer directly the user's knowledge base may contain personalized,
up-to-date, or domain-specific information that is more relevant than your general training data.
IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it.
Do NOT claim you can do something if the corresponding tool is not listed.
@ -92,29 +113,6 @@ _TOOL_INSTRUCTIONS["search_surfsense_docs"] = """
- Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123])
"""
_TOOL_INSTRUCTIONS["search_knowledge_base"] = """
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
- DEFAULT ACTION: For any user question or ambiguous query, ALWAYS call this tool first to check
for relevant context before answering from general knowledge. When in doubt, search.
- IMPORTANT: When searching for information (meetings, schedules, notes, tasks, etc.), ALWAYS search broadly
across ALL sources first by omitting connectors_to_search. The user may store information in various places
including calendar apps, note-taking apps (Obsidian, Notion), chat apps (Slack, Discord), and more.
- This tool searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
For real-time web search (current events, news, live data), use the `web_search` tool instead.
- FALLBACK BEHAVIOR: If the search returns no relevant results, you MAY then answer using your own
general knowledge, but clearly indicate that no matching information was found in the knowledge base.
- Only narrow to specific connectors if the user explicitly asks (e.g., "check my Slack" or "in my calendar").
- Personal notes in Obsidian, Notion, or NOTE often contain schedules, meeting times, reminders, and other
important information that may not be in calendars.
- Args:
- query: The search query - be specific and include key terms
- top_k: Number of results to retrieve (default: 10)
- start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
- Returns: Formatted string with relevant documents and their content
"""
_TOOL_INSTRUCTIONS["generate_podcast"] = """
- generate_podcast: Generate an audio podcast from provided content.
- Use this when the user asks to create, generate, or make a podcast.
@ -163,8 +161,8 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
* For source_strategy="kb_search": Can be empty or minimal the tool handles searching internally.
* For source_strategy="auto": Include what you have; the tool searches KB if it's not enough.
- source_strategy: Controls how the tool collects source material. One of:
* "conversation" The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. Do NOT call search_knowledge_base separately.
* "kb_search" The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. Do NOT call search_knowledge_base separately.
* "conversation" The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content.
* "kb_search" The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries.
* "auto" Use source_content if sufficient, otherwise fall back to internal KB search using search_queries.
* "provided" Use only what is in source_content (default, backward-compatible).
- search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated.
@ -176,11 +174,11 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
- The report is generated immediately in Markdown and displayed inline in the chat.
- Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report.
- SOURCE STRATEGY DECISION (HIGH PRIORITY follow this exactly):
* If the conversation already has substantive Q&A / discussion on the topic use source_strategy="conversation" with a comprehensive summary as source_content. Do NOT call search_knowledge_base first.
* If the user wants a report on a topic not yet discussed use source_strategy="kb_search" with targeted search_queries. Do NOT call search_knowledge_base first.
* If the conversation already has substantive Q&A / discussion on the topic use source_strategy="conversation" with a comprehensive summary as source_content.
* If the user wants a report on a topic not yet discussed use source_strategy="kb_search" with targeted search_queries.
* If you have some content but might need more use source_strategy="auto" with both source_content and search_queries.
* When revising an existing report (parent_report_id set) and the conversation has relevant context use source_strategy="conversation". The revision will use the previous report content plus your source_content.
* NEVER call search_knowledge_base and then pass its results to generate_report. The tool handles KB search internally.
* NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally.
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
"""
@ -204,7 +202,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
* When search_knowledge_base returned insufficient data and the user wants more
* When preloaded `/documents/` data is insufficient and the user wants more
- Trigger scenarios:
* "Read this article and summarize it"
* "What does this page say about X?"
@ -366,23 +364,6 @@ _MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
# Per-tool examples keyed by tool name. Only examples for enabled tools are included.
_TOOL_EXAMPLES: dict[str, str] = {}
_TOOL_EXAMPLES["search_knowledge_base"] = """
- User: "What time is the team meeting today?"
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
- DO NOT limit to just calendar - the info might be in notes!
- User: "When is my gym session?"
- Call: `search_knowledge_base(query="gym session time schedule")` (searches ALL sources)
- User: "Fetch all my notes and what's in them?"
- Call: `search_knowledge_base(query="*", top_k=50, connectors_to_search=["NOTE"])`
- User: "What did I discuss on Slack last week about the React migration?"
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
- User: "Check my Obsidian notes for meeting notes"
- Call: `search_knowledge_base(query="meeting notes", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
- User: "search me current usd to inr rate"
- Call: `web_search(query="current USD to INR exchange rate")`
- Then answer using the returned live web results with citations.
"""
_TOOL_EXAMPLES["search_surfsense_docs"] = """
- User: "How do I install SurfSense?"
- Call: `search_surfsense_docs(query="installation setup")`
@ -400,8 +381,7 @@ _TOOL_EXAMPLES["generate_podcast"] = """
- User: "Create a podcast summary of this conversation"
- Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
- User: "Make a podcast about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")`
"""
_TOOL_EXAMPLES["generate_video_presentation"] = """
@ -410,8 +390,7 @@ _TOOL_EXAMPLES["generate_video_presentation"] = """
- User: "Create slides summarizing this conversation"
- Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
- User: "Make a video presentation about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_video_presentation(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", video_title="Quantum Computing Explained")`
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")`
"""
_TOOL_EXAMPLES["generate_report"] = """
@ -471,7 +450,6 @@ _TOOL_EXAMPLES["web_search"] = """
# All tool names that have prompt instructions (order matters for prompt readability)
_ALL_TOOL_NAMES_ORDERED = [
"search_surfsense_docs",
"search_knowledge_base",
"web_search",
"generate_podcast",
"generate_video_presentation",
@ -650,87 +628,6 @@ However, from your video learning, it's important to note that asyncio is not su
</citation_instructions>
"""
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
SANDBOX_EXECUTION_INSTRUCTIONS = """
<code_execution>
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
## CRITICAL — CODE-FIRST RULE
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
- **Creating a chart, plot, graph, or visualization** Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
- **Data analysis, statistics, or computation** Write code to compute the answer. Do not do math by hand in text.
- **Generating or transforming files** (CSV, PDF, images, etc.) Write code to create the file.
- **Running, testing, or debugging code** Execute it in the sandbox.
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
Example (CORRECT):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Immediately execute Python code (matplotlib) to generate the pie chart
3. Return the downloadable file + brief description
Example (WRONG):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Print a text table with percentages and ask the user if they want a chart NEVER do this
## When to Use Code Execution
Use the sandbox when the task benefits from actually running code rather than just describing it:
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
- **Calculations**: Math, financial modeling, unit conversions, simulations
- **Code validation**: Run and test code snippets the user provides or asks about
- **File processing**: Parse, transform, or convert data files
- **Quick prototyping**: Demonstrate working code for the user's problem
- **Package exploration**: Install and test libraries the user is evaluating
## When NOT to Use Code Execution
Do not use the sandbox for:
- Answering factual questions from your own knowledge
- Summarizing or explaining concepts
- Simple formatting or text generation tasks
- Tasks that don't require running code to answer
## Package Management
- Use `pip install <package>` to install Python packages as needed
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
- Always verify a package installed successfully before using it
## Working Guidelines
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
- **Iterative approach**: For complex tasks, break work into steps write code, run it, check output, refine
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
- **Be efficient**: Install packages once per session. Combine related commands when possible.
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
## Sharing Generated Files
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT use markdown image syntax** for files created inside the sandbox. Sandbox files are not accessible via public URLs and will show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
## Data Analytics Best Practices
When the user asks you to analyze data:
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
2. Clean and validate before computing (handle nulls, check types)
3. Perform the analysis and present results clearly
4. Offer follow-up insights or visualizations when appropriate
</code_execution>
"""
# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
@ -756,7 +653,6 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
) -> str:
@ -767,12 +663,10 @@ def build_surfsense_system_prompt(
- Default system instructions
- Tools instructions (only for enabled tools)
- Citation instructions enabled
- Sandbox execution instructions (when sandbox_enabled=True)
Args:
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
@ -786,13 +680,7 @@ def build_surfsense_system_prompt(
visibility, enabled_tool_names, disabled_tool_names
)
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
return system_instructions + tools_instructions + citation_instructions
def build_configurable_system_prompt(
@ -801,18 +689,16 @@ def build_configurable_system_prompt(
citations_enabled: bool = True,
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
The prompt is composed of up to four parts:
The prompt is composed of three parts:
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
2. Tools Instructions - only for enabled tools, with a note about disabled ones
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
4. Sandbox Execution Instructions - when sandbox_enabled=True
Args:
custom_system_instructions: Custom system instructions to use. If empty/None and
@ -824,7 +710,6 @@ def build_configurable_system_prompt(
anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
@ -856,14 +741,7 @@ def build_configurable_system_prompt(
else SURFSENSE_NO_CITATION_INSTRUCTIONS
)
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
return system_instructions + tools_instructions + citation_instructions
def get_default_system_instructions() -> str:

View file

@ -5,7 +5,6 @@ This module contains all the tools available to the SurfSense agent.
To add a new tool, see the documentation in registry.py.
Available tools:
- search_knowledge_base: Search the user's personal knowledge base
- search_surfsense_docs: Search Surfsense documentation for usage help
- generate_podcast: Generate audio podcasts from content
- generate_video_presentation: Generate video presentations with slides and narration
@ -20,7 +19,6 @@ Available tools:
from .generate_image import create_generate_image_tool
from .knowledge_base import (
CONNECTOR_DESCRIPTIONS,
create_search_knowledge_base_tool,
format_documents_for_context,
search_knowledge_base_async,
)
@ -52,7 +50,6 @@ __all__ = [
"create_recall_memory_tool",
"create_save_memory_tool",
"create_scrape_webpage_tool",
"create_search_knowledge_base_tool",
"create_search_surfsense_docs_tool",
"format_documents_for_context",
"get_all_tool_names",

View file

@ -14,6 +14,20 @@ from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def _is_date_only(value: str) -> bool:
"""Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
return len(value) <= 10 and "T" not in value
def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
"""Build a Google Calendar start/end body using ``date`` for all-day
events and ``dateTime`` for timed events."""
if _is_date_only(value):
return {"date": value}
tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC"
return {"dateTime": value, "timeZone": tz}
def create_update_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
@ -255,25 +269,11 @@ def create_update_calendar_event_tool(
if final_new_summary is not None:
update_body["summary"] = final_new_summary
if final_new_start_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
update_body["start"] = _build_time_body(
final_new_start_datetime, context
)
update_body["start"] = {
"dateTime": final_new_start_datetime,
"timeZone": tz,
}
if final_new_end_datetime is not None:
tz = (
context.get("timezone", "UTC")
if isinstance(context, dict)
else "UTC"
)
update_body["end"] = {
"dateTime": final_new_end_datetime,
"timeZone": tz,
}
update_body["end"] = _build_time_body(final_new_end_datetime, context)
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:

View file

@ -5,7 +5,6 @@ This module provides:
- Connector constants and normalization
- Async knowledge base search across multiple connectors
- Document formatting for LLM context
- Tool factory for creating search_knowledge_base tools
"""
import asyncio
@ -16,8 +15,6 @@ import time
from datetime import datetime
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
@ -619,9 +616,76 @@ async def search_knowledge_base_async(
perf = get_perf_logger()
t0 = time.perf_counter()
deduplicated = await search_knowledge_base_raw_async(
query=query,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
connectors_to_search=connectors_to_search,
top_k=top_k,
start_date=start_date,
end_date=end_date,
available_connectors=available_connectors,
available_document_types=available_document_types,
)
if not deduplicated:
return "No documents found in the knowledge base. The search space has no indexed content yet."
# Use browse chunk cap for degenerate queries, otherwise adaptive chunking.
max_chunks_per_doc = (
_BROWSE_MAX_CHUNKS_PER_DOC if _is_degenerate_query(query) else 0
)
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
deduplicated,
max_chars=output_budget,
max_chunks_per_doc=max_chunks_per_doc,
)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
time.perf_counter() - t0,
len(deduplicated),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
async def search_knowledge_base_raw_async(
query: str,
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
connectors_to_search: list[str] | None = None,
top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""Search knowledge base and return raw document dicts (no XML formatting)."""
perf = get_perf_logger()
t0 = time.perf_counter()
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
# Preserve the public signature for compatibility even if values are unused.
_ = (db_session, connector_service)
from app.agents.new_chat.utils import resolve_date_range
resolved_start_date, resolved_end_date = resolve_date_range(
@ -631,144 +695,76 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
# --- Optimization 1: skip connectors that have zero indexed documents ---
if available_document_types:
doc_types_set = set(available_document_types)
before_count = len(connectors)
connectors = [
c
for c in connectors
if c in doc_types_set
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
]
skipped = before_count - len(connectors)
if skipped:
perf.info(
"[kb_search] skipped %d empty connectors (had %d, now %d)",
skipped,
before_count,
len(connectors),
)
perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors),
connectors[:5],
search_space_id,
top_k,
)
# --- Fast-path: no connectors left after filtering ---
if not connectors:
perf.info(
"[kb_search] TOTAL in %.3fs — no connectors to search, returning empty",
time.perf_counter() - t0,
)
return "No documents found in the knowledge base. The search space has no indexed content yet."
return []
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
# yields an empty tsquery, so both retrieval signals are useless.
# Fall back to a recency-ordered browse that returns diverse results.
if _is_degenerate_query(query):
perf.info(
"[kb_search] degenerate query %r detected - falling back to recency browse",
"[kb_search_raw] degenerate query %r detected - recency browse",
query,
)
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
expanded_browse = []
for c in browse_connectors:
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
for connector in browse_connectors:
if connector is not None and connector in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([connector, NATIVE_TO_LEGACY_DOCTYPE[connector]])
else:
expanded_browse.append(c)
expanded_browse.append(connector)
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=c,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for c in expanded_browse
]
)
for docs in browse_results:
all_documents.extend(docs)
# Skip dedup + formatting below (browse already returns unique docs)
# but still cap output budget.
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
all_documents,
max_chars=output_budget,
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
)
perf.info(
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
"budget=%d space=%d",
time.perf_counter() - t0,
len(all_documents),
len(result),
output_budget,
search_space_id,
)
return result
# --- Optimization 2: compute the query embedding once, share across all local searches ---
from app.config import config as app_config
t_embed = time.perf_counter()
precomputed_embedding = app_config.embedding_model_instance.embed(query)
perf.info(
"[kb_search] shared embedding computed in %.3fs",
time.perf_counter() - t_embed,
)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
try:
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=precomputed_embedding,
)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
for connector in expanded_browse
]
)
for docs in browse_results:
all_documents.extend(docs)
else:
if query_embedding is None:
from app.config import config as app_config
t_gather = time.perf_counter()
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
perf.info(
"[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
)
for chunks in connector_results:
all_documents.extend(chunks)
query_embedding = app_config.embedding_model_instance.embed(query)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
try:
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
return await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=query_embedding,
)
except Exception as exc:
perf.warning("[kb_search_raw] connector=%s FAILED: %s", connector, exc)
return []
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
for docs in connector_results:
all_documents.extend(docs)
# Deduplicate primarily by document ID. Only fall back to content hashing
# when a document has no ID.
seen_doc_ids: set[Any] = set()
seen_content_hashes: set[int] = set()
deduplicated: list[dict[str, Any]] = []
@ -785,7 +781,6 @@ async def search_knowledge_base_async(
chunk_texts.append(chunk_content)
if chunk_texts:
return hash("||".join(chunk_texts))
flat_content = (document.get("content") or "").strip()
if flat_content:
return hash(flat_content)
@ -793,216 +788,24 @@ async def search_knowledge_base_async(
for doc in all_documents:
doc_id = (doc.get("document", {}) or {}).get("id")
if doc_id is not None:
if doc_id in seen_doc_ids:
continue
seen_doc_ids.add(doc_id)
deduplicated.append(doc)
continue
content_hash = _content_fingerprint(doc)
if content_hash is not None and content_hash in seen_content_hashes:
continue
if content_hash is not None:
if content_hash in seen_content_hashes:
continue
seen_content_hashes.add(content_hash)
deduplicated.append(doc)
# Sort by RRF score so the most relevant documents from ANY connector
# appear first, preventing budget truncation from hiding top results.
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(deduplicated, max_chars=output_budget)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
deduplicated.sort(key=lambda doc: doc.get("score", 0), reverse=True)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
"[kb_search_raw] done in %.3fs total=%d deduped=%d",
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
"""
Build the connector documentation section for the tool docstring.
Args:
available_connectors: List of available connector types, or None for all
Returns:
Formatted docstring section listing available connectors
"""
connectors = available_connectors if available_connectors else list(_ALL_CONNECTORS)
lines = []
for connector in connectors:
# Skip internal names, prefer user-facing aliases
if connector == "CRAWLED_URL":
# Show as WEBCRAWLER_CONNECTOR for user-facing docs
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
lines.append(f"- WEBCRAWLER_CONNECTOR: {description}")
else:
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
lines.append(f"- {connector}: {description}")
return "\n".join(lines)
# =============================================================================
# Tool Input Schema
# =============================================================================
class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool."""
query: str = Field(
description=(
"The search query - use specific natural language terms. "
"NEVER use wildcards like '*' or '**'; instead describe what you want "
"(e.g. 'recent meeting notes' or 'project architecture overview')."
),
)
top_k: int = Field(
default=10,
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
)
start_date: str | None = Field(
default=None,
description="Optional ISO date/datetime (e.g. '2025-12-12' or '2025-12-12T00:00:00+00:00')",
)
end_date: str | None = Field(
default=None,
description="Optional ISO date/datetime (e.g. '2025-12-19' or '2025-12-19T23:59:59+00:00')",
)
connectors_to_search: list[str] | None = Field(
default=None,
description="Optional list of connector enums to search. If omitted, searches all available.",
)
def create_search_knowledge_base_tool(
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> StructuredTool:
"""
Factory function to create the search_knowledge_base tool with injected dependencies.
Args:
search_space_id: The user's search space ID
db_session: Database session
connector_service: Initialized connector service
available_connectors: Optional list of connector types available in the search space.
Used to dynamically generate the tool docstring.
available_document_types: Optional list of document types that have data in the search space.
Used to inform the LLM about what data exists.
max_input_tokens: Model context window (tokens) from litellm model info.
Used to dynamically size tool output.
Returns:
A configured StructuredTool instance
"""
# Build connector documentation dynamically
connector_docs = _build_connector_docstring(available_connectors)
# Build context about available document types
doc_types_info = ""
if available_document_types:
doc_types_info = f"""
## Document types with indexed content in this search space
The following document types have content available for search:
{", ".join(available_document_types)}
Focus searches on these types for best results."""
# Build the dynamic description for the tool
# This is what the LLM sees when deciding whether/how to use the tool
dynamic_description = f"""Search the user's personal knowledge base for relevant information.
Use this tool to find documents, notes, files, web pages, and other content the user has indexed.
This searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
For real-time web search (current events, news, live data), use the `web_search` tool instead.
IMPORTANT:
- Always craft specific, descriptive search queries using natural language keywords.
Good: "quarterly sales report Q3", "Python API authentication design".
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
## Available connector enums for `connectors_to_search`
{connector_docs}
NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type `CRAWLED_URL`."""
# Capture for closure
_available_connectors = available_connectors
_available_document_types = available_document_types
async def _search_knowledge_base_impl(
query: str,
top_k: int = 10,
start_date: str | None = None,
end_date: str | None = None,
connectors_to_search: list[str] | None = None,
) -> str:
"""Implementation function for knowledge base search."""
from app.agents.new_chat.utils import parse_date_or_datetime
parsed_start: datetime | None = None
parsed_end: datetime | None = None
if start_date:
parsed_start = parse_date_or_datetime(start_date)
if end_date:
parsed_end = parse_date_or_datetime(end_date)
return await search_knowledge_base_async(
query=query,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
connectors_to_search=connectors_to_search,
top_k=top_k,
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
available_document_types=_available_document_types,
max_input_tokens=max_input_tokens,
)
# Create StructuredTool with dynamic description
# This properly sets the description that the LLM sees
tool = StructuredTool(
name="search_knowledge_base",
description=dynamic_description,
coroutine=_search_knowledge_base_impl,
args_schema=SearchKnowledgeBaseInput,
)
return tool
return deduplicated

View file

@ -71,7 +71,6 @@ from .jira import (
create_delete_jira_issue_tool,
create_update_jira_issue_tool,
)
from .knowledge_base import create_search_knowledge_base_tool
from .linear import (
create_create_linear_issue_tool,
create_delete_linear_issue_tool,
@ -128,23 +127,6 @@ class ToolDefinition:
# Registry of all built-in tools
# Contributors: Add your new tools here!
BUILTIN_TOOLS: list[ToolDefinition] = [
# Core tool - searches the user's knowledge base
# Now supports dynamic connector/document type discovery
ToolDefinition(
name="search_knowledge_base",
description="Search the user's personal knowledge base for relevant information",
factory=lambda deps: create_search_knowledge_base_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
connector_service=deps["connector_service"],
# Optional: dynamically discovered connectors/document types
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
max_input_tokens=deps.get("max_input_tokens"),
),
requires=["search_space_id", "db_session", "connector_service"],
# Note: available_connectors and available_document_types are optional
),
# Podcast generation tool
ToolDefinition(
name="generate_podcast",
@ -168,8 +150,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["search_space_id", "db_session", "thread_id"],
),
# Report generation tool (inline, short-lived sessions for DB ops)
# Supports internal KB search via source_strategy so the agent doesn't
# need to call search_knowledge_base separately before generating.
# Supports internal KB search via source_strategy so the agent does not
# need a separate search step before generating.
ToolDefinition(
name="generate_report",
description="Generate a structured report from provided content and export it",
@ -551,7 +533,7 @@ def build_tools(
tools = build_tools(deps)
# Use only specific tools
tools = build_tools(deps, enabled_tools=["search_knowledge_base"])
tools = build_tools(deps, enabled_tools=["generate_report"])
# Use defaults but disable podcast
tools = build_tools(deps, disabled_tools=["generate_podcast"])

View file

@ -584,8 +584,8 @@ def create_generate_report_tool(
search_space_id: The user's search space ID
thread_id: The chat thread ID for associating the report
connector_service: Optional connector service for internal KB search.
When provided, the tool can search the knowledge base without the
agent having to call search_knowledge_base separately.
When provided, the tool can search the knowledge base internally
(used by the "kb_search" and "auto" source strategies).
available_connectors: Optional list of connector types available in the
search space (used to scope internal KB searches).
@ -639,12 +639,13 @@ def create_generate_report_tool(
SOURCE STRATEGY (how to collect source material):
- source_strategy="conversation" The conversation already has
enough context (prior Q&A, pasted text, uploaded files, scraped
webpages). Pass a thorough summary as source_content.
NEVER call search_knowledge_base separately first.
enough context (prior Q&A, filesystem exploration, pasted text,
uploaded files, scraped webpages). Pass a thorough summary as
source_content.
- source_strategy="kb_search" Search the knowledge base
internally. Provide 1-5 targeted search_queries. The tool
handles searching do NOT call search_knowledge_base first.
handles searching internally do NOT manually read and dump
/documents/ files into source_content.
- source_strategy="provided" Use only what is in source_content
(default, backward-compatible).
- source_strategy="auto" Use source_content if it has enough
@ -1064,6 +1065,7 @@ def create_generate_report_tool(
"title": topic,
"word_count": metadata.get("word_count", 0),
"is_revision": bool(parent_report_content),
"report_markdown": report_content,
"message": f"Report generated successfully: {topic}",
}

View file

@ -2,13 +2,14 @@
from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token
from .client import GoogleDriveClient
from .content_extractor import download_and_process_file
from .content_extractor import download_and_extract_content, download_and_process_file
from .credentials import get_valid_credentials, validate_credentials
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
__all__ = [
"GoogleDriveClient",
"categorize_change",
"download_and_extract_content",
"download_and_process_file",
"fetch_all_changes",
"get_file_by_id",

View file

@ -84,22 +84,50 @@ async def get_changes(
return [], None, f"Error getting changes: {e!s}"
async def _is_descendant_of(
client: GoogleDriveClient,
parent_ids: list[str],
target_folder_id: str,
max_depth: int = 20,
) -> bool:
"""Walk up the parent chain to check if any ancestor is *target_folder_id*."""
visited: set[str] = set()
to_check = list(parent_ids)
for _ in range(max_depth):
if not to_check:
return False
current = to_check.pop(0)
if current in visited:
continue
visited.add(current)
if current == target_folder_id:
return True
try:
service = await client.get_service()
meta = (
service.files()
.get(fileId=current, fields="parents", supportsAllDrives=True)
.execute()
)
grandparents = meta.get("parents", [])
to_check.extend(grandparents)
except Exception:
continue
return False
async def _filter_changes_by_folder(
client: GoogleDriveClient,
changes: list[dict[str, Any]],
folder_id: str,
) -> list[dict[str, Any]]:
"""
Filter changes to only include files within the specified folder.
Args:
client: GoogleDriveClient instance
changes: List of changes from API
folder_id: Folder ID to filter by
Returns:
Filtered list of changes
"""
"""Filter changes to only include files within the specified folder
(direct children or nested descendants)."""
filtered = []
for change in changes:
@ -108,14 +136,8 @@ async def _filter_changes_by_folder(
filtered.append(change)
continue
# Check if file is in the folder (or subfolder)
parents = file.get("parents", [])
if folder_id in parents:
filtered.append(change)
else:
# Check if any parent is a descendant of folder_id
# This is a simplified check - full implementation would traverse hierarchy
# For now, we'll include it and let indexer validate
if folder_id in parents or await _is_descendant_of(client, parents, folder_id):
filtered.append(change)
return filtered

View file

@ -1,9 +1,15 @@
"""Google Drive API client."""
import asyncio
import io
import logging
import threading
import time
from typing import Any
import httplib2
from google.oauth2.credentials import Credentials
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload
@ -12,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
logger = logging.getLogger(__name__)
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
"""Create a per-thread HTTP transport so concurrent downloads don't share
the same ``httplib2.Http`` (which is not thread-safe)."""
return AuthorizedHttp(credentials, http=httplib2.Http())
class GoogleDriveClient:
"""Client for Google Drive API operations."""
@ -34,7 +48,9 @@ class GoogleDriveClient:
self.session = session
self.connector_id = connector_id
self._credentials = credentials
self._resolved_credentials: Credentials | None = None
self.service = None
self._service_lock = asyncio.Lock()
async def get_service(self):
"""
@ -49,17 +65,22 @@ class GoogleDriveClient:
if self.service:
return self.service
try:
if self._credentials:
credentials = self._credentials
else:
credentials = await get_valid_credentials(
self.session, self.connector_id
)
self.service = build("drive", "v3", credentials=credentials)
return self.service
except Exception as e:
raise Exception(f"Failed to create Google Drive service: {e!s}") from e
async with self._service_lock:
if self.service:
return self.service
try:
if self._credentials:
credentials = self._credentials
else:
credentials = await get_valid_credentials(
self.session, self.connector_id
)
self._resolved_credentials = credentials
self.service = build("drive", "v3", credentials=credentials)
return self.service
except Exception as e:
raise Exception(f"Failed to create Google Drive service: {e!s}") from e
async def list_files(
self,
@ -134,6 +155,37 @@ class GoogleDriveClient:
except Exception as e:
return None, f"Error getting file metadata: {e!s}"
@staticmethod
def _sync_download_file(
service,
file_id: str,
credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking download — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download] START file={file_id} thread={thread}")
try:
from googleapiclient.http import MediaIoBaseDownload
http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
fh = io.BytesIO()
downloader = MediaIoBaseDownload(fh, request)
done = False
while not done:
_, done = downloader.next_chunk()
return fh.getvalue(), None
except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}"
except Exception as e:
return None, f"Error downloading file: {e!s}"
finally:
logger.info(
f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
"""
Download binary file content.
@ -144,27 +196,96 @@ class GoogleDriveClient:
Returns:
Tuple of (file content bytes, error message)
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file,
service,
file_id,
self._resolved_credentials,
)
@staticmethod
def _sync_download_file_to_disk(
service,
file_id: str,
dest_path: str,
chunksize: int,
credentials: Credentials,
) -> str | None:
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
try:
service = await self.get_service()
request = service.files().get_media(fileId=file_id)
import io
fh = io.BytesIO()
from googleapiclient.http import MediaIoBaseDownload
downloader = MediaIoBaseDownload(fh, request)
done = False
while not done:
_, done = downloader.next_chunk()
return fh.getvalue(), None
http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
with open(dest_path, "wb") as fh:
downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
done = False
while not done:
_, done = downloader.next_chunk()
return None
except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}"
return f"HTTP error downloading file: {e.resp.status}"
except Exception as e:
return None, f"Error downloading file: {e!s}"
return f"Error downloading file: {e!s}"
finally:
logger.info(
f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def download_file_to_disk(
self,
file_id: str,
dest_path: str,
chunksize: int = 5 * 1024 * 1024,
) -> str | None:
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
Returns error message on failure, None on success.
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file_to_disk,
service,
file_id,
dest_path,
chunksize,
self._resolved_credentials,
)
@staticmethod
def _sync_export_google_file(
service,
file_id: str,
mime_type: str,
credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking export — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[export] START file={file_id} thread={thread}")
try:
http = _build_thread_http(credentials)
content = (
service.files()
.export(fileId=file_id, mimeType=mime_type)
.execute(http=http)
)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
finally:
logger.info(
f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def export_google_file(
self, file_id: str, mime_type: str
@ -179,23 +300,14 @@ class GoogleDriveClient:
Returns:
Tuple of (exported content as bytes, error message)
"""
try:
service = await self.get_service()
content = (
service.files().export(fileId=file_id, mimeType=mime_type).execute()
)
# Content is already bytes from the API
# Keep as bytes to support both text and binary formats (like PDF)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
service = await self.get_service()
return await asyncio.to_thread(
self._sync_export_google_file,
service,
file_id,
mime_type,
self._resolved_credentials,
)
async def create_file(
self,

View file

@ -1,8 +1,12 @@
"""Content extraction for Google Drive files."""
import asyncio
import contextlib
import logging
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
@ -12,11 +16,195 @@ from app.db import Log
from app.services.task_logging_service import TaskLoggingService
from .client import GoogleDriveClient
from .file_types import get_export_mime_type, is_google_workspace_file, should_skip_file
from .file_types import (
get_export_mime_type,
get_extension_from_mime,
is_google_workspace_file,
should_skip_file,
)
logger = logging.getLogger(__name__)
async def download_and_extract_content(
client: GoogleDriveClient,
file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
"""Download a Google Drive file and extract its content as markdown.
ETL only -- no DB writes, no indexing, no summarization.
Returns:
(markdown_content, drive_metadata, error_message)
On success error_message is None.
"""
file_id = file.get("id")
file_name = file.get("name", "Unknown")
mime_type = file.get("mimeType", "")
if should_skip_file(mime_type):
return None, {}, f"Skipping {mime_type}"
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
drive_metadata: dict[str, Any] = {
"google_drive_file_id": file_id,
"google_drive_file_name": file_name,
"google_drive_mime_type": mime_type,
"source_connector": "google_drive",
}
if "modifiedTime" in file:
drive_metadata["modified_time"] = file["modifiedTime"]
if "createdTime" in file:
drive_metadata["created_time"] = file["createdTime"]
if "size" in file:
drive_metadata["file_size"] = file["size"]
if "webViewLink" in file:
drive_metadata["web_view_link"] = file["webViewLink"]
if "md5Checksum" in file:
drive_metadata["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
drive_metadata["exported_as"] = export_ext.lstrip(".") if export_ext else "pdf"
drive_metadata["original_workspace_type"] = mime_type.split(".")[-1]
temp_file_path = None
try:
if is_google_workspace_file(mime_type):
export_mime = get_export_mime_type(mime_type)
if not export_mime:
return (
None,
drive_metadata,
f"Cannot export Google Workspace type: {mime_type}",
)
content_bytes, error = await client.export_google_file(file_id, export_mime)
if error:
return None, drive_metadata, error
extension = get_extension_from_mime(export_mime) or ".pdf"
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
tmp.write(content_bytes)
temp_file_path = tmp.name
else:
extension = (
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
temp_file_path = tmp.name
error = await client.download_file_to_disk(file_id, temp_file_path)
if error:
return None, drive_metadata, error
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
return markdown, drive_metadata, None
except Exception as e:
logger.warning(f"Failed to extract content from {file_name}: {e!s}")
return None, drive_metadata, str(e)
finally:
if temp_file_path and os.path.exists(temp_file_path):
with contextlib.suppress(Exception):
os.unlink(temp_file_path)
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
"""Parse a local file to markdown using the configured ETL service."""
lower = filename.lower()
if lower.endswith((".md", ".markdown", ".txt")):
with open(file_path, encoding="utf-8") as f:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from litellm import atranscription
from app.config import config as app_config
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict[str, Any] = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
resp = await atranscription(**kwargs)
text = resp.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
# Document files -- use configured ETL service
from app.config import config as app_config
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return await convert_document_to_markdown(docs)
if app_config.ETL_SERVICE == "LLAMACLOUD":
from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(
file_path=file_path, estimated_pages=50
)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
return markdown_documents[0].text
if app_config.ETL_SERVICE == "DOCLING":
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(
f"[docling] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
async def download_and_process_file(
client: GoogleDriveClient,
file: dict[str, Any],
@ -68,14 +256,15 @@ async def download_and_process_file(
if error:
return None, error
extension = ".pdf" if export_mime == "application/pdf" else ".txt"
extension = get_extension_from_mime(export_mime) or ".pdf"
else:
content_bytes, error = await client.download_file(file_id)
if error:
return None, error
# Preserve original file extension
extension = Path(file_name).suffix or ".bin"
extension = (
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
tmp_file.write(content_bytes)
@ -113,7 +302,10 @@ async def download_and_process_file(
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
connector_info["metadata"]["exported_as"] = "pdf"
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
connector_info["metadata"]["exported_as"] = (
export_ext.lstrip(".") if export_ext else "pdf"
)
connector_info["metadata"]["original_workspace_type"] = mime_type.split(
"."
)[-1]

View file

@ -7,11 +7,34 @@ GOOGLE_FOLDER = "application/vnd.google-apps.folder"
GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut"
EXPORT_FORMATS = {
GOOGLE_DOC: "application/pdf",
GOOGLE_SHEET: "application/pdf",
GOOGLE_DOC: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
GOOGLE_SLIDE: "application/pdf",
}
MIME_TO_EXTENSION: dict[str, str] = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"application/vnd.ms-excel": ".xls",
"application/msword": ".doc",
"application/vnd.ms-powerpoint": ".ppt",
"text/plain": ".txt",
"text/csv": ".csv",
"text/html": ".html",
"text/markdown": ".md",
"application/json": ".json",
"application/xml": ".xml",
"image/png": ".png",
"image/jpeg": ".jpg",
}
def get_extension_from_mime(mime_type: str) -> str | None:
"""Return a file extension (with leading dot) for a MIME type, or None."""
return MIME_TO_EXTENSION.get(mime_type)
def is_google_workspace_file(mime_type: str) -> bool:
"""Check if file is a Google Workspace file that needs export."""

View file

@ -914,6 +914,43 @@ class SharedMemory(BaseModel, TimestampMixin):
created_by = relationship("User")
class Folder(BaseModel, TimestampMixin):
__tablename__ = "folders"
name = Column(String(255), nullable=False, index=True)
position = Column(String(50), nullable=False, index=True)
parent_id = Column(
Integer,
ForeignKey("folders.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
search_space_id = Column(
Integer,
ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
created_by_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
index=True,
)
parent = relationship("Folder", remote_side="Folder.id", backref="children")
search_space = relationship("SearchSpace", back_populates="folders")
created_by = relationship("User", back_populates="folders")
documents = relationship("Document", back_populates="folder", passive_deletes=True)
class Document(BaseModel, TimestampMixin):
__tablename__ = "documents"
@ -947,6 +984,13 @@ class Document(BaseModel, TimestampMixin):
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
folder_id = Column(
Integer,
ForeignKey("folders.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Track who created/uploaded this document
created_by_id = Column(
UUID(as_uuid=True),
@ -976,6 +1020,7 @@ class Document(BaseModel, TimestampMixin):
# Relationships
search_space = relationship("SearchSpace", back_populates="documents")
folder = relationship("Folder", back_populates="documents")
created_by = relationship("User", back_populates="documents")
connector = relationship("SearchSourceConnector", back_populates="documents")
chunks = relationship(
@ -1279,6 +1324,12 @@ class SearchSpace(BaseModel, TimestampMixin):
)
user = relationship("User", back_populates="search_spaces")
folders = relationship(
"Folder",
back_populates="search_space",
order_by="Folder.position",
cascade="all, delete-orphan",
)
documents = relationship(
"Document",
back_populates="search_space",
@ -1765,6 +1816,13 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True,
)
# Folders created by this user
folders = relationship(
"Folder",
back_populates="created_by",
passive_deletes=True,
)
# Image generations created by this user
image_generations = relationship(
"ImageGeneration",
@ -1867,6 +1925,13 @@ else:
passive_deletes=True,
)
# Folders created by this user
folders = relationship(
"Folder",
back_populates="created_by",
passive_deletes=True,
)
# Image generations created by this user
image_generations = relationship(
"ImageGeneration",

View file

@ -3,10 +3,19 @@ import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_identifier_hash(
document_type_value: str, unique_id: str, search_space_id: int
) -> str:
"""Return a stable SHA-256 hash from raw identity components."""
combined = f"{document_type_value}:{unique_id}:{search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
return compute_identifier_hash(
doc.document_type.value, doc.unique_id, doc.search_space_id
)
def compute_content_hash(doc: ConnectorDocument) -> str:

View file

@ -1,17 +1,29 @@
import asyncio
import contextlib
import hashlib
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
DocumentStatus,
DocumentType,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
@ -48,12 +60,166 @@ from app.indexing_pipeline.pipeline_logger import (
from app.utils.perf import get_perf_logger
@dataclass
class PlaceholderInfo:
"""Minimal info to create a placeholder document row for instant UI feedback.
These are created immediately when items are discovered (before content
extraction) so users see them in the UI via Zero sync right away.
"""
title: str
document_type: DocumentType
unique_id: str
search_space_id: int
connector_id: int | None
created_by_id: str
metadata: dict = field(default_factory=dict)
class IndexingPipelineService:
"""Single pipeline for indexing connector documents. All connectors use this service."""
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def create_placeholder_documents(
self, placeholders: list[PlaceholderInfo]
) -> int:
"""Create placeholder document rows with pending status for instant UI feedback.
These rows appear immediately in the UI via Zero sync. They are later
updated by prepare_for_indexing() when actual content is available.
Returns the number of placeholders successfully created.
Failures are logged but never block the main indexing flow.
NOTE: This method commits on ``self.session`` so the rows become
visible to Zero sync immediately. Any pending ORM mutations on the
session are committed together, which is consistent with how other
mid-flow commits work in the indexing codebase (e.g. rename-only
updates in ``_should_skip_file``, ``migrate_legacy_docs``).
"""
if not placeholders:
return 0
_logger = logging.getLogger(__name__)
uid_hashes: dict[str, PlaceholderInfo] = {}
for p in placeholders:
try:
uid_hash = compute_identifier_hash(
p.document_type.value, p.unique_id, p.search_space_id
)
uid_hashes.setdefault(uid_hash, p)
except Exception:
_logger.debug(
"Skipping placeholder hash for %s", p.unique_id, exc_info=True
)
if not uid_hashes:
return 0
result = await self.session.execute(
select(Document.unique_identifier_hash).where(
Document.unique_identifier_hash.in_(list(uid_hashes.keys()))
)
)
existing_hashes: set[str] = set(result.scalars().all())
created = 0
for uid_hash, p in uid_hashes.items():
if uid_hash in existing_hashes:
continue
try:
content_hash = hashlib.sha256(
f"placeholder:{uid_hash}".encode()
).hexdigest()
document = Document(
title=p.title,
document_type=p.document_type,
content="Pending...",
content_hash=content_hash,
unique_identifier_hash=uid_hash,
document_metadata=p.metadata or {},
search_space_id=p.search_space_id,
connector_id=p.connector_id,
created_by_id=p.created_by_id,
updated_at=datetime.now(UTC),
status=DocumentStatus.pending(),
)
self.session.add(document)
created += 1
except Exception:
_logger.debug("Skipping placeholder for %s", p.unique_id, exc_info=True)
if created > 0:
try:
await self.session.commit()
_logger.info(
"Created %d placeholder document(s) for instant UI feedback",
created,
)
except IntegrityError:
await self.session.rollback()
_logger.debug("Placeholder commit failed (race condition), continuing")
created = 0
return created
async def migrate_legacy_docs(
self, connector_docs: list[ConnectorDocument]
) -> None:
"""Migrate legacy Composio documents to their native Google type.
For each ConnectorDocument whose document_type has a Composio equivalent
in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and
update its unique_identifier_hash and document_type so that
prepare_for_indexing() can find it under the native hash.
"""
for doc in connector_docs:
legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value)
if not legacy_type:
continue
legacy_hash = compute_identifier_hash(
legacy_type, doc.unique_id, doc.search_space_id
)
result = await self.session.execute(
select(Document).filter(Document.unique_identifier_hash == legacy_hash)
)
existing = result.scalars().first()
if existing is None:
continue
native_hash = compute_identifier_hash(
doc.document_type.value, doc.unique_id, doc.search_space_id
)
existing.unique_identifier_hash = native_hash
existing.document_type = doc.document_type
await self.session.commit()
async def index_batch(
self, connector_docs: list[ConnectorDocument], llm
) -> list[Document]:
"""Convenience method: prepare_for_indexing then index each document.
Indexers that need heartbeat callbacks or custom per-document logic
should call prepare_for_indexing() + index() directly instead.
"""
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
documents = await self.prepare_for_indexing(connector_docs)
results: list[Document] = []
for document in documents:
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
continue
result = await self.index(document, connector_doc, llm)
results.append(result)
return results
async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument]
) -> list[Document]:
@ -106,6 +272,21 @@ class IndexingPipelineService:
log_document_requeued(ctx)
continue
dup_check = await self.session.execute(
select(Document.id).filter(
Document.content_hash == content_hash,
Document.id != existing.id,
)
)
if dup_check.scalars().first() is not None:
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.failed(
"Duplicate content — already indexed by another document"
)
continue
existing.title = connector_doc.title
existing.content_hash = content_hash
existing.source_markdown = connector_doc.source_markdown
@ -200,13 +381,14 @@ class IndexingPipelineService:
)
t_step = time.perf_counter()
chunk_texts = chunk_text(
chunk_texts = await asyncio.to_thread(
chunk_text,
connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker,
)
texts_to_embed = [content, *chunk_texts]
embeddings = embed_texts(texts_to_embed)
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
summary_embedding, *chunk_embeddings = embeddings
chunks = [
@ -268,3 +450,120 @@ class IndexingPipelineService:
await self.session.refresh(document)
return document
async def index_batch_parallel(
self,
connector_docs: list[ConnectorDocument],
get_llm: Callable[[AsyncSession], Awaitable],
*,
max_concurrency: int = 4,
on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
heartbeat_interval: float = 30.0,
) -> tuple[list[Document], int, int]:
"""Index documents in parallel with bounded concurrency.
Phase 1 (serial): prepare_for_indexing using self.session.
Phase 2 (parallel): index each document in an isolated session,
bounded by a semaphore to avoid overwhelming APIs/DB.
"""
logger = logging.getLogger(__name__)
perf = get_perf_logger()
t_total = time.perf_counter()
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
documents = await self.prepare_for_indexing(connector_docs)
if not documents:
return [], 0, 0
from app.tasks.celery_tasks import get_celery_session_maker
sem = asyncio.Semaphore(max_concurrency)
lock = asyncio.Lock()
indexed_count = 0
failed_count = 0
results: list[Document] = []
last_heartbeat = time.time()
async def _index_one(document: Document) -> Document | Exception:
nonlocal indexed_count, failed_count, last_heartbeat
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
logger.warning(
"No matching ConnectorDocument for document %s, skipping",
document.id,
)
async with lock:
failed_count += 1
return document
async with sem:
session_maker = get_celery_session_maker()
async with session_maker() as isolated_session:
try:
refetched = await isolated_session.get(Document, document.id)
if refetched is None:
async with lock:
failed_count += 1
return document
llm = await get_llm(isolated_session)
iso_pipeline = IndexingPipelineService(isolated_session)
result = await iso_pipeline.index(refetched, connector_doc, llm)
async with lock:
if DocumentStatus.is_state(
result.status, DocumentStatus.READY
):
indexed_count += 1
else:
failed_count += 1
if on_heartbeat:
now = time.time()
if now - last_heartbeat >= heartbeat_interval:
await on_heartbeat(indexed_count)
last_heartbeat = now
return result
except Exception as exc:
logger.error(
"Parallel index failed for doc %s: %s",
document.id,
exc,
exc_info=True,
)
async with lock:
failed_count += 1
return exc
tasks = [_index_one(doc) for doc in documents]
t_parallel = time.perf_counter()
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
perf.info(
"[indexing] index_batch_parallel gather docs=%d concurrency=%d "
"indexed=%d failed=%d in %.3fs",
len(documents),
max_concurrency,
indexed_count,
failed_count,
time.perf_counter() - t_parallel,
)
for outcome in outcomes:
if isinstance(outcome, Document):
results.append(outcome)
elif isinstance(outcome, Exception):
pass
perf.info(
"[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
"indexed=%d failed=%d in %.3fs",
len(connector_docs),
len(documents),
indexed_count,
failed_count,
time.perf_counter() - t_total,
)
return results, indexed_count, failed_count

View file

@ -5,7 +5,7 @@ from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
_MAX_FETCH_CHUNKS_PER_DOC = 20
class ChucksHybridSearchRetriever:
@ -185,7 +185,7 @@ class ChucksHybridSearchRetriever:
- chunks: list[{chunk_id, content}] for citation-aware prompting
- document: {id, title, document_type, metadata}
"""
from sqlalchemy import func, select, text
from sqlalchemy import func, or_, select, text
from sqlalchemy.orm import joinedload
from app.config import config
@ -360,64 +360,81 @@ class ChucksHybridSearchRetriever:
if not doc_ids:
return []
# Fetch chunks for selected documents. We cap per document to avoid
# loading hundreds of chunks for a single large file while still
# ensuring the chunks that matched the RRF query are always included.
chunk_query = (
select(Chunk)
.options(joinedload(Chunk.document))
.join(Document, Chunk.document_id == Document.id)
.where(Document.id.in_(doc_ids))
.where(*base_conditions)
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunk_query)
raw_chunks = chunks_result.scalars().all()
# Collect document metadata from the small RRF result set (already
# loaded via joinedload) so the bulk chunk fetch can skip the expensive
# Document JOIN entirely.
matched_chunk_ids: set[int] = {
item["chunk_id"] for item in serialized_chunk_results
}
doc_meta_cache: dict[int, dict] = {}
for item in serialized_chunk_results:
did = item["document"]["id"]
if did not in doc_meta_cache:
doc_meta_cache[did] = item["document"]
doc_chunk_counts: dict[int, int] = {}
all_chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
all_chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# SQL-level per-document chunk limit using ROW_NUMBER().
# Avoids loading hundreds of chunks per large document only to
# discard them in Python.
numbered = (
select(
Chunk.id.label("chunk_id"),
func.row_number()
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
.label("rn"),
)
.where(Chunk.document_id.in_(doc_ids))
.subquery("numbered")
)
# Assemble final doc-grouped results in the same order as doc_ids
matched_list = list(matched_chunk_ids)
if matched_list:
chunk_filter = or_(
numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC,
Chunk.id.in_(matched_list),
)
else:
chunk_filter = numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
chunk_query = (
select(Chunk.id, Chunk.content, Chunk.document_id)
.join(numbered, Chunk.id == numbered.c.chunk_id)
.where(chunk_filter)
.order_by(Chunk.document_id, Chunk.id)
)
t_fetch = time.perf_counter()
chunks_result = await self.db_session.execute(chunk_query)
fetched_chunks = chunks_result.all()
perf.debug(
"[chunk_search] chunk fetch in %.3fs rows=%d",
time.perf_counter() - t_fetch,
len(fetched_chunks),
)
# Assemble final doc-grouped results in the same order as doc_ids,
# using pre-cached doc metadata instead of joinedload.
doc_map: dict[int, dict] = {
doc_id: {
"document_id": doc_id,
"content": "",
"score": float(doc_scores.get(doc_id, 0.0)),
"chunks": [],
"document": {},
"source": None,
"matched_chunk_ids": [],
"document": doc_meta_cache.get(doc_id, {}),
"source": (doc_meta_cache.get(doc_id) or {}).get("document_type"),
}
for doc_id in doc_ids
}
for chunk in all_chunks:
doc = chunk.document
doc_id = doc.id
for row in fetched_chunks:
doc_id = row.document_id
if doc_id not in doc_map:
continue
doc_entry = doc_map[doc_id]
doc_entry["document"] = {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
}
doc_entry["source"] = (
doc.document_type.value if getattr(doc, "document_type", None) else None
)
doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})
doc_entry["chunks"].append({"chunk_id": row.id, "content": row.content})
if row.id in matched_chunk_ids:
doc_entry["matched_chunk_ids"].append(row.id)
# Fill concatenated content (useful for reranking)
final_docs: list[dict] = []

View file

@ -4,7 +4,7 @@ from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
_MAX_FETCH_CHUNKS_PER_DOC = 20
class DocumentHybridSearchRetriever:
@ -289,57 +289,77 @@ class DocumentHybridSearchRetriever:
if not documents_with_scores:
return []
# Collect document IDs for chunk fetching
# Collect document IDs and pre-cache metadata from the small RRF
# result set so the bulk chunk fetch can skip joinedload entirely.
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
# Fetch chunks for these documents, capped per document to avoid
# loading hundreds of chunks for a single large file.
chunks_query = (
select(Chunk)
.options(joinedload(Chunk.document))
.where(Chunk.document_id.in_(doc_ids))
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
raw_chunks = chunks_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _MAX_FETCH_CHUNKS_PER_DOC:
chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results
doc_map: dict[int, dict] = {
doc.id: {
"document_id": doc.id,
"content": "",
"score": float(score),
"chunks": [],
"document": {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
},
"source": doc.document_type.value
doc_meta_cache: dict[int, dict] = {}
doc_score_cache: dict[int, float] = {}
doc_source_cache: dict[int, str | None] = {}
for doc, score in documents_with_scores:
doc_meta_cache[doc.id] = {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
}
for doc, score in documents_with_scores
doc_score_cache[doc.id] = float(score)
doc_source_cache[doc.id] = (
doc.document_type.value if getattr(doc, "document_type", None) else None
)
# SQL-level per-document chunk limit using ROW_NUMBER().
# Avoids loading hundreds of chunks per large document only to
# discard them in Python.
numbered = (
select(
Chunk.id.label("chunk_id"),
func.row_number()
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
.label("rn"),
)
.where(Chunk.document_id.in_(doc_ids))
.subquery("numbered")
)
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
chunks_query = (
select(Chunk.id, Chunk.content, Chunk.document_id)
.join(numbered, Chunk.id == numbered.c.chunk_id)
.where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC)
.order_by(Chunk.document_id, Chunk.id)
)
t_fetch = time.perf_counter()
chunks_result = await self.db_session.execute(chunks_query)
fetched_chunks = chunks_result.all()
perf.debug(
"[doc_search] chunk fetch in %.3fs rows=%d",
time.perf_counter() - t_fetch,
len(fetched_chunks),
)
# Assemble doc-grouped results using pre-cached metadata.
doc_map: dict[int, dict] = {
doc_id: {
"document_id": doc_id,
"content": "",
"score": doc_score_cache.get(doc_id, 0.0),
"chunks": [],
"matched_chunk_ids": [],
"document": doc_meta_cache.get(doc_id, {}),
"source": doc_source_cache.get(doc_id),
}
for doc_id in doc_ids
}
for chunk in chunks:
doc_id = chunk.document_id
for row in fetched_chunks:
doc_id = row.document_id
if doc_id not in doc_map:
continue
doc_map[doc_id]["chunks"].append(
{"chunk_id": chunk.id, "content": chunk.content}
{"chunk_id": row.id, "content": row.content}
)
# Fill concatenated content (useful for reranking)

View file

@ -11,6 +11,7 @@ from .confluence_add_connector_route import router as confluence_add_connector_r
from .discord_add_connector_route import router as discord_add_connector_router
from .documents_routes import router as documents_router
from .editor_routes import router as editor_router
from .folders_routes import router as folders_router
from .google_calendar_add_connector_route import (
router as google_calendar_add_connector_router,
)
@ -51,6 +52,7 @@ router.include_router(search_spaces_router)
router.include_router(rbac_router) # RBAC routes for roles, members, invites
router.include_router(editor_router)
router.include_router(documents_router)
router.include_router(folders_router)
router.include_router(notes_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)

View file

@ -320,6 +320,7 @@ async def read_documents(
page_size: int = 50,
search_space_id: int | None = None,
document_types: str | None = None,
folder_id: int | str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
session: AsyncSession = Depends(get_async_session),
@ -391,6 +392,17 @@ async def read_documents(
query = query.filter(Document.document_type.in_(type_list))
count_query = count_query.filter(Document.document_type.in_(type_list))
# Filter by folder_id: "root" or "null" => root level (folder_id IS NULL),
# integer => specific folder, omitted => all documents
if folder_id is not None:
if str(folder_id).lower() in ("root", "null"):
query = query.filter(Document.folder_id.is_(None))
count_query = count_query.filter(Document.folder_id.is_(None))
else:
fid = int(folder_id)
query = query.filter(Document.folder_id == fid)
count_query = count_query.filter(Document.folder_id == fid)
total_result = await session.execute(count_query)
total = total_result.scalar() or 0
@ -451,6 +463,7 @@ async def read_documents(
created_at=doc.created_at,
updated_at=doc.updated_at,
search_space_id=doc.search_space_id,
folder_id=doc.folder_id,
created_by_id=doc.created_by_id,
created_by_name=created_by_name,
created_by_email=created_by_email,
@ -608,6 +621,7 @@ async def search_documents(
created_at=doc.created_at,
updated_at=doc.updated_at,
search_space_id=doc.search_space_id,
folder_id=doc.folder_id,
created_by_id=doc.created_by_id,
created_by_name=created_by_name,
created_by_email=created_by_email,
@ -978,6 +992,7 @@ async def read_document(
created_at=document.created_at,
updated_at=document.updated_at,
search_space_id=document.search_space_id,
folder_id=document.folder_id,
)
except HTTPException:
raise
@ -1036,6 +1051,7 @@ async def update_document(
created_at=db_document.created_at,
updated_at=db_document.updated_at,
search_space_id=db_document.search_space_id,
folder_id=db_document.folder_id,
)
except HTTPException:
raise

View file

@ -1,19 +1,42 @@
"""
Editor routes for document editing with markdown (Plate.js frontend).
Includes multi-format export (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text).
"""
import asyncio
import io
import logging
import os
import tempfile
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
import pypandoc
import typst
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import (
_FILE_EXTENSIONS,
_MEDIA_TYPES,
ExportFormat,
_normalize_latex_delimiters,
_strip_wrapping_code_fences,
)
from app.templates.export_helpers import (
get_html_css_path,
get_reference_docx_path,
get_typst_template_path,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
router = APIRouter()
@ -212,3 +235,162 @@ async def save_document(
"message": "Document saved and will be reindexed in the background",
"updated_at": document.updated_at.isoformat(),
}
@router.get("/search-spaces/{search_space_id}/documents/{document_id}/export")
async def export_document(
search_space_id: int,
document_id: int,
format: ExportFormat = Query(
ExportFormat.PDF,
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Export a document in the requested format (reuses the report export pipeline)."""
await check_permission(
session,
user,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
.filter(
Document.id == document_id,
Document.search_space_id == search_space_id,
)
)
document = result.scalars().first()
if not document:
raise HTTPException(status_code=404, detail="Document not found")
# Resolve markdown content (same priority as editor-content endpoint)
markdown_content: str | None = document.source_markdown
if markdown_content is None and document.blocknote_document:
from app.utils.blocknote_to_markdown import blocknote_to_markdown
markdown_content = blocknote_to_markdown(document.blocknote_document)
if markdown_content is None:
chunks = sorted(document.chunks, key=lambda c: c.id)
if chunks:
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
if not markdown_content or not markdown_content.strip():
raise HTTPException(status_code=400, detail="Document has no content to export")
markdown_content = _strip_wrapping_code_fences(markdown_content)
markdown_content = _normalize_latex_delimiters(markdown_content)
doc_title = document.title or "Document"
formatted_date = (
document.created_at.strftime("%B %d, %Y") if document.created_at else ""
)
input_fmt = "gfm+tex_math_dollars"
meta_args = ["-M", f"title:{doc_title}", "-M", f"date:{formatted_date}"]
def _convert_and_read() -> bytes:
if format == ExportFormat.PDF:
typst_template = str(get_typst_template_path())
typst_markup: str = pypandoc.convert_text(
markdown_content,
"typst",
format=input_fmt,
extra_args=[
"--standalone",
f"--template={typst_template}",
"-V",
"mainfont:Libertinus Serif",
"-V",
"codefont:DejaVu Sans Mono",
*meta_args,
],
)
return typst.compile(typst_markup.encode("utf-8"))
if format == ExportFormat.DOCX:
return _pandoc_to_tempfile(
format.value,
[
"--standalone",
f"--reference-doc={get_reference_docx_path()}",
*meta_args,
],
)
if format == ExportFormat.HTML:
html_str: str = pypandoc.convert_text(
markdown_content,
"html5",
format=input_fmt,
extra_args=[
"--standalone",
"--embed-resources",
f"--css={get_html_css_path()}",
"--syntax-highlighting=pygments",
*meta_args,
],
)
return html_str.encode("utf-8")
if format == ExportFormat.EPUB:
return _pandoc_to_tempfile("epub3", ["--standalone", *meta_args])
if format == ExportFormat.ODT:
return _pandoc_to_tempfile("odt", ["--standalone", *meta_args])
if format == ExportFormat.LATEX:
tex_str: str = pypandoc.convert_text(
markdown_content,
"latex",
format=input_fmt,
extra_args=["--standalone", *meta_args],
)
return tex_str.encode("utf-8")
plain_str: str = pypandoc.convert_text(
markdown_content,
"plain",
format=input_fmt,
extra_args=["--wrap=auto", "--columns=80"],
)
return plain_str.encode("utf-8")
def _pandoc_to_tempfile(output_format: str, extra_args: list[str]) -> bytes:
fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
os.close(fd)
try:
pypandoc.convert_text(
markdown_content,
output_format,
format=input_fmt,
extra_args=extra_args,
outputfile=tmp_path,
)
with open(tmp_path, "rb") as f:
return f.read()
finally:
os.unlink(tmp_path)
try:
loop = asyncio.get_running_loop()
output = await loop.run_in_executor(None, _convert_and_read)
except Exception as e:
logger.exception("Document export failed")
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
safe_title = (
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title).strip()[:80]
or "document"
)
ext = _FILE_EXTENSIONS[format]
return StreamingResponse(
io.BytesIO(output),
media_type=_MEDIA_TYPES[format],
headers={"Content-Disposition": f'attachment; filename="{safe_title}.{ext}"'},
)

View file

@ -0,0 +1,516 @@
"""API routes for folder CRUD, move, reorder, and document move operations."""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Document, Folder, Permission, User, get_async_session
from app.schemas import (
BulkDocumentMove,
DocumentMove,
FolderBreadcrumb,
FolderCreate,
FolderMove,
FolderRead,
FolderReorder,
FolderUpdate,
)
from app.services.folder_service import (
check_no_circular_reference,
generate_folder_position,
get_folder_subtree_ids,
get_subtree_max_depth,
validate_folder_depth,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
@router.post("/folders", response_model=FolderRead)
async def create_folder(
request: FolderCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create a new folder. Requires DOCUMENTS_CREATE permission."""
try:
await check_permission(
session,
user,
request.search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create folders in this search space",
)
if request.parent_id is not None:
parent = await session.get(Folder, request.parent_id)
if not parent:
raise HTTPException(status_code=404, detail="Parent folder not found")
if parent.search_space_id != request.search_space_id:
raise HTTPException(
status_code=400,
detail="Parent folder belongs to a different search space",
)
await validate_folder_depth(session, request.parent_id)
position = await generate_folder_position(
session, request.search_space_id, request.parent_id
)
folder = Folder(
name=request.name,
position=position,
parent_id=request.parent_id,
search_space_id=request.search_space_id,
created_by_id=user.id,
)
session.add(folder)
await session.commit()
await session.refresh(folder)
return folder
except HTTPException:
raise
except Exception as e:
await session.rollback()
if "uq_folder_space_parent_name" in str(e):
raise HTTPException(
status_code=409,
detail="A folder with this name already exists at this location",
) from e
raise HTTPException(
status_code=500, detail=f"Failed to create folder: {e!s}"
) from e
@router.get("/folders", response_model=list[FolderRead])
async def list_folders(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
try:
await check_permission(
session,
user,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
)
result = await session.execute(
select(Folder)
.where(Folder.search_space_id == search_space_id)
.order_by(Folder.position)
)
return result.scalars().all()
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to list folders: {e!s}"
) from e
@router.get("/folders/{folder_id}", response_model=FolderRead)
async def get_folder(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Get a single folder. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
)
return folder
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to get folder: {e!s}"
) from e
@router.get("/folders/{folder_id}/breadcrumb", response_model=list[FolderBreadcrumb])
async def get_folder_breadcrumb(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
)
result = await session.execute(
text("""
WITH RECURSIVE ancestors AS (
SELECT id, name, parent_id, 0 AS depth
FROM folders WHERE id = :folder_id
UNION ALL
SELECT f.id, f.name, f.parent_id, a.depth + 1
FROM folders f JOIN ancestors a ON f.id = a.parent_id
)
SELECT id, name FROM ancestors ORDER BY depth DESC;
"""),
{"folder_id": folder_id},
)
rows = result.fetchall()
return [FolderBreadcrumb(id=row.id, name=row.name) for row in rows]
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to get breadcrumb: {e!s}"
) from e
@router.put("/folders/{folder_id}", response_model=FolderRead)
async def update_folder(
folder_id: int,
request: FolderUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Rename a folder. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update folders in this search space",
)
folder.name = request.name
await session.commit()
await session.refresh(folder)
return folder
except HTTPException:
raise
except Exception as e:
await session.rollback()
if "uq_folder_space_parent_name" in str(e):
raise HTTPException(
status_code=409,
detail="A folder with this name already exists at this location",
) from e
raise HTTPException(
status_code=500, detail=f"Failed to update folder: {e!s}"
) from e
@router.put("/folders/{folder_id}/move", response_model=FolderRead)
async def move_folder(
folder_id: int,
request: FolderMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move folders in this search space",
)
if request.new_parent_id is not None:
new_parent = await session.get(Folder, request.new_parent_id)
if not new_parent:
raise HTTPException(
status_code=404, detail="Target parent folder not found"
)
if new_parent.search_space_id != folder.search_space_id:
raise HTTPException(
status_code=400,
detail="Cannot move folder to a different search space",
)
await check_no_circular_reference(session, folder_id, request.new_parent_id)
subtree_depth = await get_subtree_max_depth(session, folder_id)
await validate_folder_depth(session, request.new_parent_id, subtree_depth)
position = await generate_folder_position(
session, folder.search_space_id, request.new_parent_id
)
folder.parent_id = request.new_parent_id
folder.position = position
await session.commit()
await session.refresh(folder)
return folder
except HTTPException:
raise
except Exception as e:
await session.rollback()
if "uq_folder_space_parent_name" in str(e):
raise HTTPException(
status_code=409,
detail="A folder with this name already exists at the target location",
) from e
raise HTTPException(
status_code=500, detail=f"Failed to move folder: {e!s}"
) from e
@router.put("/folders/{folder_id}/reorder", response_model=FolderRead)
async def reorder_folder(
folder_id: int,
request: FolderReorder,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to reorder folders in this search space",
)
position = await generate_folder_position(
session,
folder.search_space_id,
folder.parent_id,
before_position=request.before_position,
after_position=request.after_position,
)
folder.position = position
await session.commit()
await session.refresh(folder)
return folder
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to reorder folder: {e!s}"
) from e
@router.delete("/folders/{folder_id}")
async def delete_folder(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Delete a folder and cascade-delete subfolders. Documents are async-deleted via Celery."""
try:
folder = await session.get(Folder, folder_id)
if not folder:
raise HTTPException(status_code=404, detail="Folder not found")
await check_permission(
session,
user,
folder.search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete folders in this search space",
)
subtree_ids = await get_folder_subtree_ids(session, folder_id)
doc_result = await session.execute(
select(Document.id).where(
Document.folder_id.in_(subtree_ids),
Document.status["state"].as_string() != "deleting",
)
)
document_ids = list(doc_result.scalars().all())
if document_ids:
await session.execute(
Document.__table__.update()
.where(Document.id.in_(document_ids))
.values(status={"state": "deleting"})
)
await session.commit()
await session.execute(Folder.__table__.delete().where(Folder.id == folder_id))
await session.commit()
if document_ids:
try:
from app.tasks.celery_tasks.document_tasks import (
delete_folder_documents_task,
)
delete_folder_documents_task.delay(document_ids)
except Exception as err:
await session.execute(
Document.__table__.update()
.where(Document.id.in_(document_ids))
.values(status={"state": "ready"})
)
await session.commit()
raise HTTPException(
status_code=503,
detail="Folder deleted but document cleanup could not be queued. Documents have been restored.",
) from err
return {
"message": "Folder deleted successfully",
"documents_queued_for_deletion": len(document_ids),
}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to delete folder: {e!s}"
) from e
@router.put("/documents/{document_id}/move")
async def move_document(
document_id: int,
request: DocumentMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
result = await session.execute(
select(Document).filter(Document.id == document_id)
)
document = result.scalars().first()
if not document:
raise HTTPException(status_code=404, detail="Document not found")
await check_permission(
session,
user,
document.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space",
)
if request.folder_id is not None:
target = await session.get(Folder, request.folder_id)
if not target:
raise HTTPException(status_code=404, detail="Target folder not found")
if target.search_space_id != document.search_space_id:
raise HTTPException(
status_code=400,
detail="Cannot move document to a folder in a different search space",
)
document.folder_id = request.folder_id
await session.commit()
return {"message": "Document moved successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to move document: {e!s}"
) from e
@router.put("/documents/bulk-move")
async def bulk_move_documents(
request: BulkDocumentMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
if not request.document_ids:
raise HTTPException(status_code=400, detail="No document IDs provided")
result = await session.execute(
select(Document).filter(Document.id.in_(request.document_ids))
)
documents = result.scalars().all()
if not documents:
raise HTTPException(status_code=404, detail="No documents found")
search_space_ids = {doc.search_space_id for doc in documents}
for ss_id in search_space_ids:
await check_permission(
session,
user,
ss_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space",
)
if request.folder_id is not None:
target = await session.get(Folder, request.folder_id)
if not target:
raise HTTPException(status_code=404, detail="Target folder not found")
mismatched = [
doc.id
for doc in documents
if doc.search_space_id != target.search_space_id
]
if mismatched:
raise HTTPException(
status_code=400,
detail="Cannot move documents to a folder in a different search space",
)
await session.execute(
Document.__table__.update()
.where(Document.id.in_(request.document_ids))
.values(folder_id=request.folder_id)
)
await session.commit()
return {"message": f"{len(request.document_ids)} documents moved successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to move documents: {e!s}"
) from e

View file

@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
try:
from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files,
index_google_drive_single_file,
index_google_drive_selected_files,
)
# Parse the structured data
@ -2402,25 +2402,27 @@ async def run_google_drive_indexing(
exc_info=True,
)
# Index each individual file
for file in items.files:
# Index all selected files together via the parallel pipeline
if items.files:
try:
indexed_count, error_message = await index_google_drive_single_file(
file_tuples = [(f.id, f.name) for f in items.files]
(
indexed_count,
_skipped,
file_errors,
) = await index_google_drive_selected_files(
session,
connector_id,
search_space_id,
user_id,
file_id=file.id,
file_name=file.name,
files=file_tuples,
)
if error_message:
errors.append(f"File '{file.name}': {error_message}")
else:
total_indexed += indexed_count
total_indexed += indexed_count
errors.extend(file_errors)
except Exception as e:
errors.append(f"File '{file.name}': {e!s}")
errors.append(f"File batch indexing: {e!s}")
logger.error(
f"Error indexing file {file.name} ({file.id}): {e}",
f"Error batch indexing files: {e}",
exc_info=True,
)

View file

@ -22,6 +22,16 @@ from .documents import (
ExtensionDocumentMetadata,
PaginatedResponse,
)
from .folders import (
BulkDocumentMove,
DocumentMove,
FolderBreadcrumb,
FolderCreate,
FolderMove,
FolderRead,
FolderReorder,
FolderUpdate,
)
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
from .image_generation import (
GlobalImageGenConfigRead,
@ -109,6 +119,8 @@ from .video_presentations import (
)
__all__ = [
# Folder schemas
"BulkDocumentMove",
# Chat schemas (assistant-ui integration)
"ChatMessage",
# Chunk schemas
@ -119,6 +131,7 @@ __all__ = [
"DefaultSystemInstructionsResponse",
# Document schemas
"DocumentBase",
"DocumentMove",
"DocumentRead",
"DocumentStatusBatchResponse",
"DocumentStatusItemRead",
@ -132,6 +145,12 @@ __all__ = [
"DriveItem",
"ExtensionDocumentContent",
"ExtensionDocumentMetadata",
"FolderBreadcrumb",
"FolderCreate",
"FolderMove",
"FolderRead",
"FolderReorder",
"FolderUpdate",
"GlobalImageGenConfigRead",
"GlobalNewLLMConfigRead",
"GoogleDriveIndexRequest",

View file

@ -59,6 +59,7 @@ class DocumentRead(BaseModel):
created_at: datetime
updated_at: datetime | None
search_space_id: int
folder_id: int | None = None
created_by_id: UUID | None = None # User who created/uploaded this document
created_by_name: str | None = None
created_by_email: str | None = None
@ -89,6 +90,7 @@ class DocumentTitleRead(BaseModel):
id: int
title: str
document_type: DocumentType
folder_id: int | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -0,0 +1,52 @@
"""Pydantic schemas for folder CRUD, move, and reorder operations."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class FolderCreate(BaseModel):
name: str = Field(max_length=255, min_length=1)
parent_id: int | None = None
search_space_id: int
class FolderUpdate(BaseModel):
name: str = Field(max_length=255, min_length=1)
class FolderMove(BaseModel):
new_parent_id: int | None = None
class FolderReorder(BaseModel):
before_position: str | None = None
after_position: str | None = None
class FolderRead(BaseModel):
id: int
name: str
position: str
parent_id: int | None
search_space_id: int
created_by_id: UUID | None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class FolderBreadcrumb(BaseModel):
id: int
name: str
class DocumentMove(BaseModel):
folder_id: int | None = None
class BulkDocumentMove(BaseModel):
document_ids: list[int]
folder_id: int | None = None

View file

@ -0,0 +1,158 @@
"""Folder service: depth validation, circular reference checks, and position generation."""
from fastapi import HTTPException
from fractional_indexing import generate_key_between
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Folder
MAX_FOLDER_DEPTH = 8
async def get_folder_depth(session: AsyncSession, folder_id: int) -> int:
"""Return the depth of a folder (root-level = 1) using a recursive CTE."""
result = await session.execute(
text("""
WITH RECURSIVE ancestors AS (
SELECT id, parent_id, 1 AS depth
FROM folders
WHERE id = :folder_id
UNION ALL
SELECT f.id, f.parent_id, a.depth + 1
FROM folders f
JOIN ancestors a ON f.id = a.parent_id
)
SELECT MAX(depth) FROM ancestors;
"""),
{"folder_id": folder_id},
)
return result.scalar() or 0
async def get_subtree_max_depth(session: AsyncSession, folder_id: int) -> int:
"""Return the maximum depth of any descendant below folder_id (0 if leaf)."""
result = await session.execute(
text("""
WITH RECURSIVE descendants AS (
SELECT id, 0 AS depth
FROM folders
WHERE parent_id = :folder_id
UNION ALL
SELECT f.id, d.depth + 1
FROM folders f
JOIN descendants d ON f.parent_id = d.id
)
SELECT COALESCE(MAX(depth), -1) FROM descendants;
"""),
{"folder_id": folder_id},
)
val = result.scalar()
return (val + 1) if val is not None and val >= 0 else 0
async def validate_folder_depth(
session: AsyncSession,
parent_id: int | None,
subtree_depth: int = 0,
) -> None:
"""Raise 400 if placing a folder (with subtree) under parent_id would exceed MAX_FOLDER_DEPTH."""
if parent_id is None:
parent_depth = 0
else:
parent_depth = await get_folder_depth(session, parent_id)
total = parent_depth + 1 + subtree_depth
if total > MAX_FOLDER_DEPTH:
raise HTTPException(
status_code=400,
detail=f"Maximum folder nesting depth is {MAX_FOLDER_DEPTH}. "
f"This operation would result in depth {total}.",
)
async def check_no_circular_reference(
session: AsyncSession,
folder_id: int,
new_parent_id: int | None,
) -> None:
"""Raise 400 if new_parent_id is folder_id itself or a descendant of folder_id."""
if new_parent_id is None:
return
if new_parent_id == folder_id:
raise HTTPException(
status_code=400,
detail="A folder cannot be moved into itself.",
)
result = await session.execute(
text("""
WITH RECURSIVE ancestors AS (
SELECT id, parent_id
FROM folders
WHERE id = :new_parent_id
UNION ALL
SELECT f.id, f.parent_id
FROM folders f
JOIN ancestors a ON f.id = a.parent_id
)
SELECT 1 FROM ancestors WHERE id = :folder_id LIMIT 1;
"""),
{"new_parent_id": new_parent_id, "folder_id": folder_id},
)
if result.scalar() is not None:
raise HTTPException(
status_code=400,
detail="Cannot move a folder into one of its own descendants.",
)
async def generate_folder_position(
session: AsyncSession,
search_space_id: int,
parent_id: int | None,
before_position: str | None = None,
after_position: str | None = None,
) -> str:
"""Generate a fractional index key for ordering a folder among its siblings.
- Default (no before/after): append after last sibling
- Prepend: before_position=None, after_position=first sibling position
- Insert between: both positions provided
"""
if before_position is not None or after_position is not None:
return generate_key_between(before_position, after_position)
# Append after last sibling
query = (
select(Folder.position)
.where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
)
.order_by(Folder.position.desc())
.limit(1)
)
result = await session.execute(query)
last_position = result.scalar()
return generate_key_between(last_position, None)
async def get_folder_subtree_ids(session: AsyncSession, folder_id: int) -> list[int]:
"""Return all folder IDs in the subtree rooted at folder_id (inclusive)."""
result = await session.execute(
text("""
WITH RECURSIVE subtree AS (
SELECT id FROM folders WHERE id = :folder_id
UNION ALL
SELECT f.id FROM folders f JOIN subtree s ON f.parent_id = s.id
)
SELECT id FROM subtree;
"""),
{"folder_id": folder_id},
)
return list(result.scalars().all())

View file

@ -209,8 +209,8 @@ class GoogleCalendarKBSyncService:
)
calendar_id = (document.document_metadata or {}).get(
"calendar_id", "primary"
)
"calendar_id"
) or "primary"
live_event = await loop.run_in_executor(
None,
lambda: (

View file

@ -133,6 +133,51 @@ async def _delete_document_background(document_id: int) -> None:
await session.commit()
@celery_app.task(
name="delete_folder_documents_background",
bind=True,
autoretry_for=(Exception,),
retry_backoff=True,
retry_backoff_max=300,
max_retries=5,
)
def delete_folder_documents_task(self, document_ids: list[int]):
"""Celery task to batch-delete documents orphaned by folder deletion."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_folder_documents(document_ids))
finally:
loop.close()
async def _delete_folder_documents(document_ids: list[int]) -> None:
"""Delete chunks in batches, then document rows for each orphaned document."""
from sqlalchemy import delete as sa_delete, select
from app.db import Chunk, Document
async with get_celery_session_maker()() as session:
batch_size = 500
for doc_id in document_ids:
while True:
chunk_ids_result = await session.execute(
select(Chunk.id)
.where(Chunk.document_id == doc_id)
.limit(batch_size)
)
chunk_ids = chunk_ids_result.scalars().all()
if not chunk_ids:
break
await session.execute(sa_delete(Chunk).where(Chunk.id.in_(chunk_ids)))
await session.commit()
doc = await session.get(Document, doc_id)
if doc:
await session.delete(doc)
await session.commit()
@celery_app.task(
name="delete_search_space_background",
bind=True,

View file

@ -9,6 +9,7 @@ Supports loading LLM configurations from:
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
"""
import ast
import asyncio
import contextlib
import gc
@ -36,10 +37,6 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
)
from app.db import (
ChatVisibility,
Document,
@ -212,7 +209,7 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
async def _stream_agent_events(
@ -281,6 +278,8 @@ async def _stream_agent_events(
if event_type == "on_chat_model_stream":
if active_tool_depth > 0:
continue # Suppress inner-tool LLM tokens from leaking into chat
if "surfsense:internal" in event.get("tags", []):
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
@ -319,19 +318,114 @@ async def _stream_agent_events(
tool_step_ids[run_id] = tool_step_id
last_active_step_id = tool_step_id
if tool_name == "search_knowledge_base":
query = (
tool_input.get("query", "")
if tool_name == "ls":
ls_path = (
tool_input.get("path", "/")
if isinstance(tool_input, dict)
else str(tool_input)
)
last_active_step_title = "Searching knowledge base"
last_active_step_title = "Listing files"
last_active_step_items = [ls_path]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Listing files",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "read_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Reading file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Reading file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "write_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Writing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Writing file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "edit_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Editing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Editing file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "glob":
pat = (
tool_input.get("pattern", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
base_path = (
tool_input.get("path", "/") if isinstance(tool_input, dict) else "/"
)
last_active_step_title = "Searching files"
last_active_step_items = [f"{pat} in {base_path}"]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Searching files",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "grep":
pat = (
tool_input.get("pattern", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
grep_path = (
tool_input.get("path", "") if isinstance(tool_input, dict) else ""
)
display_pat = pat[:60] + ("" if len(pat) > 60 else "")
last_active_step_title = "Searching content"
last_active_step_items = [
f"Query: {query[:100]}{'...' if len(query) > 100 else ''}"
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Searching knowledge base",
title="Searching content",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "save_document":
doc_title = (
tool_input.get("title", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_title = doc_title[:60] + ("" if len(doc_title) > 60 else "")
last_active_step_title = "Saving document"
last_active_step_items = [display_title]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Saving document",
status="in_progress",
items=last_active_step_items,
)
@ -441,10 +535,22 @@ async def _stream_agent_events(
else streaming_service.generate_tool_call_id()
)
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
_safe_input: dict[str, Any] = {}
for _k, _v in tool_input.items():
try:
json.dumps(_v)
_safe_input[_k] = _v
except (TypeError, ValueError, OverflowError):
pass
else:
_safe_input = {"input": tool_input}
yield streaming_service.format_tool_input_available(
tool_call_id,
tool_name,
tool_input if isinstance(tool_input, dict) else {"input": tool_input},
_safe_input,
)
elif event_type == "on_tool_end":
@ -475,16 +581,55 @@ async def _stream_agent_events(
)
completed_step_ids.add(original_step_id)
if tool_name == "search_knowledge_base":
result_info = "Search completed"
if isinstance(tool_output, dict):
result_len = tool_output.get("result_length", 0)
if result_len > 0:
result_info = f"Found relevant information ({result_len} chars)"
completed_items = [*last_active_step_items, result_info]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching knowledge base",
title="Reading file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Writing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "edit_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Editing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "glob":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching files",
status="completed",
items=last_active_step_items,
)
elif tool_name == "grep":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching content",
status="completed",
items=last_active_step_items,
)
elif tool_name == "save_document":
result_str = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
is_error = "Error" in result_str
completed_items = [
*last_active_step_items,
result_str[:80] if is_error else "Saved to knowledge base",
]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Saving document",
status="completed",
items=completed_items,
)
@ -690,14 +835,23 @@ async def _stream_agent_events(
ls_output = str(tool_output) if tool_output else ""
file_names: list[str] = []
if ls_output:
for line in ls_output.strip().split("\n"):
line = line.strip()
if line:
name = line.rstrip("/").split("/")[-1]
if name and len(name) <= 40:
file_names.append(name)
elif name:
file_names.append(name[:37] + "...")
paths: list[str] = []
try:
parsed = ast.literal_eval(ls_output)
if isinstance(parsed, list):
paths = [str(p) for p in parsed]
except (ValueError, SyntaxError):
paths = [
line.strip()
for line in ls_output.strip().split("\n")
if line.strip()
]
for p in paths:
name = p.rstrip("/").split("/")[-1]
if name and len(name) <= 40:
file_names.append(name)
elif name:
file_names.append(name[:37] + "...")
if file_names:
if len(file_names) <= 5:
completed_items = [f"[{name}]" for name in file_names]
@ -708,7 +862,7 @@ async def _stream_agent_events(
completed_items = ["No files found"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Exploring files",
title="Listing files",
status="completed",
items=completed_items,
)
@ -832,14 +986,6 @@ async def _stream_agent_events(
f"Scrape failed: {error_msg}",
"error",
)
elif tool_name == "search_knowledge_base":
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
@ -973,6 +1119,19 @@ async def _stream_agent_events(
items=last_active_step_items,
)
elif (
event_type == "on_custom_event" and event.get("name") == "document_created"
):
data = event.get("data", {})
if data.get("id"):
yield streaming_service.format_data(
"documents-updated",
{
"action": "created",
"document": data,
},
)
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -995,38 +1154,6 @@ async def _stream_agent_events(
yield streaming_service.format_interrupt_request(result.interrupt_value)
def _try_persist_and_delete_sandbox(
thread_id: int,
sandbox_files: list[str],
) -> None:
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if not is_sandbox_enabled():
return
async def _run() -> None:
try:
await persist_and_delete_sandbox(thread_id, sandbox_files)
except Exception:
logging.getLogger(__name__).warning(
"persist_and_delete_sandbox failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_run())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def stream_new_chat(
user_query: str,
search_space_id: int,
@ -1141,22 +1268,6 @@ async def stream_new_chat(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
@ -1170,7 +1281,6 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
disabled_tools=disabled_tools,
)
_perf_log.info(
@ -1531,8 +1641,6 @@ async def stream_new_chat(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
@ -1541,7 +1649,7 @@ async def stream_new_chat(
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
agent = llm = connector_service = None
input_state = stream_result = None
session = None
@ -1627,22 +1735,6 @@ async def stream_resume_chat(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
@ -1657,7 +1749,6 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -1742,15 +1833,13 @@ async def stream_resume_chat(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
agent = llm = connector_service = None
stream_result = None
session = None

View file

@ -1,49 +1,77 @@
"""
Confluence connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
"""Confluence connector indexer using the unified parallel indexing pipeline."""
import contextlib
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
full_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Confluence page dict to a ConnectorDocument."""
page_id = page.get("id", "")
page_title = page.get("title", "")
space_id = page.get("spaceId", "")
comment_count = len(page.get("comments", []))
metadata = {
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
fallback_summary = (
f"Confluence Page: {page_title}\n\nSpace ID: {space_id}\n\n{full_content}"
)
return ConnectorDocument(
title=page_title,
source_markdown=full_content,
unique_id=page_id,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_confluence_pages(
session: AsyncSession,
connector_id: int,
@ -53,26 +81,9 @@ async def index_confluence_pages(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
"""
Index Confluence pages and comments.
Args:
session: Database session
connector_id: ID of the Confluence connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
) -> tuple[int, int, str | None]:
"""Index Confluence pages and comments."""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="confluence_pages_indexing",
source="connector_indexing_task",
@ -86,7 +97,6 @@ async def index_confluence_pages(
)
try:
# Get the connector from the database
connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR
)
@ -98,9 +108,8 @@ async def index_confluence_pages(
"Connector not found",
{"error_type": "ConnectorNotFound"},
)
return 0, f"Connector with ID {connector_id} not found"
return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Confluence OAuth client
await task_logger.log_task_progress(
log_entry,
f"Initializing Confluence OAuth client for connector {connector_id}",
@ -114,7 +123,6 @@ async def index_confluence_pages(
)
)
# Calculate date range
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
@ -129,19 +137,14 @@ async def index_confluence_pages(
},
)
# Get pages within date range
try:
pages, error = await confluence_client.get_pages_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
)
if error:
# Don't treat "No pages found" as an error that should stop indexing
if "No pages found" in error:
logger.info(f"No Confluence pages found: {error}")
logger.info(
"No pages found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
@ -156,11 +159,10 @@ async def index_confluence_pages(
f"No Confluence pages found in date range {start_date_str} to {end_date_str}",
{"pages_found": 0},
)
# Close client before returning
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, None
return 0, 0, None
else:
logger.error(f"Failed to get Confluence pages: {error}")
await task_logger.log_task_failure(
@ -169,42 +171,62 @@ async def index_confluence_pages(
"API Error",
{"error_type": "APIError"},
)
# Close client on error
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, f"Failed to get Confluence pages: {error}"
return 0, 0, f"Failed to get Confluence pages: {error}"
logger.info(f"Retrieved {len(pages)} pages from Confluence API")
except Exception as e:
logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True)
# Close client on error
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, f"Error fetching Confluence pages: {e!s}"
return 0, 0, f"Error fetching Confluence pages: {e!s}"
if not pages:
logger.info("No Confluence pages found for the specified date range")
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=page.get("title", ""),
document_type=DocumentType.CONFLUENCE_CONNECTOR,
unique_id=page.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"page_id": page.get("id", ""),
"connector_id": connector_id,
"connector_type": "Confluence",
},
)
for page in pages
if page.get("id") and page.get("title")
]
await pipeline.create_placeholder_documents(placeholders)
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
documents_indexed = 0
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
connector_docs: list[ConnectorDocument] = []
for page in pages:
try:
page_id = page.get("id")
page_title = page.get("title", "")
space_id = page.get("spaceId", "")
page.get("spaceId", "")
if not page_id or not page_title:
logger.warning(
@ -213,12 +235,10 @@ async def index_confluence_pages(
documents_skipped += 1
continue
# Extract page content
page_content = ""
if page.get("body") and page["body"].get("storage"):
page_content = page["body"]["storage"].get("value", "")
# Add comments to content
comments = page.get("comments", [])
comments_content = ""
if comments:
@ -235,61 +255,25 @@ async def index_confluence_pages(
comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
# Combine page content with comments
full_content = f"# {page_title}\n\n{page_content}{comments_content}"
if not full_content.strip():
if not page_content.strip() and not comments:
logger.warning(f"Skipping page with no content: {page_title}")
documents_skipped += 1
continue
# Generate unique identifier hash for this Confluence page
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id
doc = _build_connector_doc(
page,
full_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(full_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(comments)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
pages_to_process.append(
{
"document": existing_document,
"is_new": False,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
session, compute_content_hash(doc)
)
if duplicate_by_content:
@ -302,151 +286,30 @@ async def index_confluence_pages(
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
document_metadata={
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
for item in pages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["full_content"], user_llm, document_metadata
)
else:
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content with comments
chunks = await create_document_chunks(item["full_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_id": item["page_id"],
"page_title": item["page_title"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Confluence pages processed so far"
)
await session.commit()
connector_docs.append(doc)
except Exception as e:
logger.error(
f"Error processing page {item.get('page_title', 'Unknown')}: {e!s}",
exc_info=True,
f"Error building ConnectorDocument for page: {e!s}", exc_info=True
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue # Skip this page and continue with others
documents_skipped += 1
continue
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s: AsyncSession):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info(
f"Final commit: Total {documents_indexed} Confluence pages processed"
)
@ -456,7 +319,6 @@ async def index_confluence_pages(
"Successfully committed all Confluence document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
@ -467,11 +329,9 @@ async def index_confluence_pages(
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -479,7 +339,6 @@ async def index_confluence_pages(
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Confluence indexing for connector {connector_id}",
@ -490,22 +349,19 @@ async def index_confluence_pages(
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"Confluence indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
# Close the client connection
if confluence_client:
await confluence_client.close()
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
# Close client if it exists
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
@ -516,10 +372,9 @@ async def index_confluence_pages(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
# Close client if it exists
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
@ -530,4 +385,4 @@ async def index_confluence_pages(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True)
return 0, f"Failed to index Confluence pages: {e!s}"
return 0, 0, f"Failed to index Confluence pages: {e!s}"

View file

@ -1,12 +1,10 @@
"""
Google Calendar connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -15,29 +13,25 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_calendar_connector import GoogleCalendarConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
)
from .base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
parse_date_flexible,
safe_set_chunks,
update_connector_last_indexed,
)
@ -46,13 +40,58 @@ ACCEPTED_CALENDAR_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
}
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
event: dict,
event_markdown: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Google Calendar API event dict to a ConnectorDocument."""
event_id = event.get("id", "")
event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
metadata = {
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
fallback_summary = f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
return ConnectorDocument(
title=event_summary,
source_markdown=event_markdown,
unique_id=event_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_calendar_events(
session: AsyncSession,
connector_id: int,
@ -82,7 +121,6 @@ async def index_google_calendar_events(
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="google_calendar_events_indexing",
source="connector_indexing_task",
@ -96,7 +134,7 @@ async def index_google_calendar_events(
)
try:
# Accept both native and Composio Calendar connectors
# ── Connector lookup ──────────────────────────────────────────
connector = None
for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct)
@ -112,7 +150,7 @@ async def index_google_calendar_events(
)
return 0, 0, f"Connector with ID {connector_id} not found"
# Build credentials based on connector type
# ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -184,6 +222,7 @@ async def index_google_calendar_events(
)
return 0, 0, "Google Calendar credentials not found in connector config"
# ── Calendar client init ──────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Google Calendar client for connector {connector_id}",
@ -203,36 +242,26 @@ async def index_google_calendar_events(
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range
# For calendar connectors, allow future dates to index upcoming events
# ── Date range calculation ────────────────────────────────────
if start_date is None or end_date is None:
# Fall back to calculating dates based on last_indexed_at
# Default to today (users can manually select future dates if needed)
calculated_end_date = datetime.now()
# Use last_indexed_at as start date if available, otherwise use 30 days ago
if connector.last_indexed_at:
# Convert dates to be comparable (both timezone-naive)
last_indexed_naive = (
connector.last_indexed_at.replace(tzinfo=None)
if connector.last_indexed_at.tzinfo
else connector.last_indexed_at
)
# Allow future dates - use last_indexed_at as start date
calculated_start_date = last_indexed_naive
logger.info(
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
)
else:
calculated_start_date = datetime.now() - timedelta(
days=365
) # Use 365 days as default for calendar events (matches frontend)
calculated_start_date = datetime.now() - timedelta(days=365)
logger.info(
f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date"
)
# Use calculated dates if not provided
start_date_str = (
start_date if start_date else calculated_start_date.strftime("%Y-%m-%d")
)
@ -240,19 +269,14 @@ async def index_google_calendar_events(
end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
)
else:
# Use provided dates (including future dates)
start_date_str = start_date
end_date_str = end_date
# FIX: Ensure end_date is at least 1 day after start_date to avoid
# "start_date must be strictly before end_date" errors when dates are the same
# (e.g., when last_indexed_at is today)
if start_date_str == end_date_str:
logger.info(
f"Start date ({start_date_str}) equals end date ({end_date_str}), "
"adjusting end date to next day to ensure valid date range"
)
# Parse end_date and add 1 day
try:
end_dt = parse_date_flexible(end_date_str)
except ValueError:
@ -264,6 +288,7 @@ async def index_google_calendar_events(
end_date_str = end_dt.strftime("%Y-%m-%d")
logger.info(f"Adjusted end date to {end_date_str}")
# ── Fetch events ──────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
@ -274,27 +299,19 @@ async def index_google_calendar_events(
},
)
# Get events within date range from primary calendar
try:
events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str
)
if error:
# Don't treat "No events found" as an error that should stop indexing
if "No events found" in error:
logger.info(f"No Google Calendar events found: {error}")
logger.info(
"No events found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no events found"
)
await task_logger.log_task_success(
log_entry,
@ -304,7 +321,6 @@ async def index_google_calendar_events(
return 0, 0, None
else:
logger.error(f"Failed to get Google Calendar events: {error}")
# Check if this is an authentication error that requires re-authentication
error_message = error
error_type = "APIError"
if (
@ -329,28 +345,36 @@ async def index_google_calendar_events(
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
documents_indexed = 0
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=event.get("summary", "No Title"),
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
unique_id=event.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"event_id": event.get("id", ""),
"connector_id": connector_id,
"connector_type": "Google Calendar",
},
)
for event in events
if event.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track events that failed processing
duplicate_content_count = (
0 # Track events skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all events, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
events_to_process = [] # List of dicts with document and event data
new_documents_created = False
duplicate_content_count = 0
for event in events:
try:
event_id = event.get("id")
event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
if not event_id:
logger.warning(f"Skipping event with missing ID: {event_summary}")
@ -363,246 +387,55 @@ async def index_google_calendar_events(
documents_skipped += 1
continue
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
description = event.get("description", "")
# Generate unique identifier hash for this Google Calendar event
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
doc = _build_connector_doc(
event,
event_markdown,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(event_markdown, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
event_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_CALENDAR_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Calendar document: {event_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
events_to_process.append(
{
"document": existing_document,
"is_new": False,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate_by_content:
# A document with the same content already exists (likely from Composio connector)
if duplicate:
logger.info(
f"Event {event_summary} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping to avoid duplicate content."
f"Event {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=event_summary,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
document_metadata={
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
events_to_process.append(
{
"document": document,
"is_new": True,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
for item in events_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
logger.error(
f"Error building ConnectorDocument for event: {e!s}", exc_info=True
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"] or "No location",
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])
# Update document to READY with actual content
document.title = item["event_summary"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ─────────────
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(
f"Final commit: Total {documents_indexed} Google Calendar events processed"
)
@ -612,22 +445,18 @@ async def index_google_calendar_events(
"Successfully committed all Google Calendar document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same event was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")

View file

@ -1,12 +1,10 @@
"""
Google Gmail connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
@ -15,21 +13,15 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_gmail_connector import GoogleGmailConnector
from app.db import (
Document,
DocumentStatus,
DocumentType,
SearchSourceConnectorType,
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials,
@ -37,12 +29,9 @@ from app.utils.google_credentials import (
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
@ -51,13 +40,70 @@ ACCEPTED_GMAIL_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
}
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
message: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Gmail API message dict to a ConnectorDocument."""
message_id = message.get("id", "")
thread_id = message.get("threadId", "")
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
metadata = {
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
fallback_summary = (
f"Google Gmail Message: {subject}\n\n"
f"From: {sender}\nDate: {date_str}\n\n"
f"{markdown_content}"
)
return ConnectorDocument(
title=subject,
source_markdown=markdown_content,
unique_id=message_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_gmail_messages(
session: AsyncSession,
connector_id: int,
@ -80,7 +126,7 @@ async def index_google_gmail_messages(
start_date: Start date for filtering messages (YYYY-MM-DD format)
end_date: End date for filtering messages (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
max_messages: Maximum number of messages to fetch (default: 100)
max_messages: Maximum number of messages to fetch (default: 1000)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
@ -88,7 +134,6 @@ async def index_google_gmail_messages(
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="google_gmail_messages_indexing",
source="connector_indexing_task",
@ -103,7 +148,7 @@ async def index_google_gmail_messages(
)
try:
# Accept both native and Composio Gmail connectors
# ── Connector lookup ──────────────────────────────────────────
connector = None
for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct)
@ -117,7 +162,7 @@ async def index_google_gmail_messages(
)
return 0, 0, error_msg
# Build credentials based on connector type
# ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
@ -189,6 +234,7 @@ async def index_google_gmail_messages(
)
return 0, 0, "Google gmail credentials not found in connector config"
# ── Gmail client init ─────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Google gmail client for connector {connector_id}",
@ -199,14 +245,11 @@ async def index_google_gmail_messages(
credentials, session, user_id, connector_id
)
# Calculate date range using last_indexed_at if dates not provided
# This ensures Gmail uses the same date logic as other connectors
# (uses last_indexed_at → now, or 365 days back for first-time indexing)
calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Fetch recent Google gmail messages
# ── Fetch messages ────────────────────────────────────────────
logger.info(
f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}"
@ -218,7 +261,6 @@ async def index_google_gmail_messages(
)
if error:
# Check if this is an authentication error that requires re-authentication
error_message = error
error_type = "APIError"
if (
@ -243,286 +285,103 @@ async def index_google_gmail_messages(
logger.info(f"Found {len(messages)} Google gmail messages to index")
documents_indexed = 0
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
def _gmail_subject(msg: dict) -> str:
for h in msg.get("payload", {}).get("headers", []):
if h.get("name", "").lower() == "subject":
return h.get("value", "No Subject")
return "No Subject"
placeholders = [
PlaceholderInfo(
title=_gmail_subject(msg),
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=msg.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"message_id": msg.get("id", ""),
"connector_id": connector_id,
"connector_type": "Google Gmail",
},
)
for msg in messages
if msg.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track messages that failed processing
duplicate_content_count = (
0 # Track messages skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all messages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
new_documents_created = False
duplicate_content_count = 0
for message in messages:
try:
# Extract message information
message_id = message.get("id", "")
thread_id = message.get("threadId", "")
# Extract headers for subject and sender
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
if not message_id:
logger.warning(f"Skipping message with missing ID: {subject}")
logger.warning("Skipping message with missing ID")
documents_skipped += 1
continue
# Format message to markdown
markdown_content = gmail_connector.format_message_to_markdown(message)
if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {subject}")
logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1
continue
# Generate unique identifier hash for this Gmail message
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id
doc = _build_connector_doc(
message,
markdown_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(markdown_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
message_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GMAIL_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_GMAIL_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Gmail document: {message_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate_by_content:
if duplicate:
logger.info(
f"Gmail message {subject} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
f"Gmail message {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=subject,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
document_metadata={
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
messages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
for item in messages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
logger.error(
f"Error building ConnectorDocument for message: {e!s}",
exc_info=True,
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["subject"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Gmail messages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ─────────────
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
try:
await session.commit()
@ -530,22 +389,18 @@ async def index_google_gmail_messages(
"Successfully committed all Google Gmail document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same message was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -555,7 +410,6 @@ async def index_google_gmail_messages(
total_processed = documents_indexed
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Google Gmail indexing for connector {connector_id}",

View file

@ -1,49 +1,83 @@
"""
Jira connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
"""Jira connector indexer using the unified parallel indexing pipeline."""
import contextlib
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Jira issue dict to a ConnectorDocument."""
issue_id = issue.get("key", "")
issue_identifier = issue.get("key", "")
issue_title = issue.get("id", "")
state = formatted_issue.get("status", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Jira Issue",
"connector_type": "Jira",
}
fallback_summary = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.JIRA_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_jira_issues(
session: AsyncSession,
connector_id: int,
@ -53,26 +87,9 @@ async def index_jira_issues(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
"""
Index Jira issues and comments.
Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
) -> tuple[int, int, str | None]:
"""Index Jira issues and comments."""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="jira_issues_indexing",
source="connector_indexing_task",
@ -86,7 +103,6 @@ async def index_jira_issues(
)
try:
# Get the connector from the database
connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR
)
@ -98,24 +114,15 @@ async def index_jira_issues(
"Connector not found",
{"error_type": "ConnectorNotFound"},
)
return 0, f"Connector with ID {connector_id} not found"
return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Jira client with internal refresh capability
# Token refresh will happen automatically when needed
await task_logger.log_task_progress(
log_entry,
f"Initializing Jira client for connector {connector_id}",
{"stage": "client_initialization"},
)
logger.info(f"Initializing Jira client for connector {connector_id}")
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
jira_client = JiraHistoryConnector(session=session, connector_id=connector_id)
# Calculate date range
# Handle "undefined" strings from frontend
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
@ -135,19 +142,14 @@ async def index_jira_issues(
},
)
# Get issues within date range
try:
issues, error = await jira_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
)
if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error:
logger.info(f"No Jira issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
@ -162,7 +164,8 @@ async def index_jira_issues(
f"No Jira issues found in date range {start_date_str} to {end_date_str}",
{"issues_found": 0},
)
return 0, None
await jira_client.close()
return 0, 0, None
else:
logger.error(f"Failed to get Jira issues: {error}")
await task_logger.log_task_failure(
@ -171,29 +174,51 @@ async def index_jira_issues(
"API Error",
{"error_type": "APIError"},
)
return 0, f"Failed to get Jira issues: {error}"
await jira_client.close()
return 0, 0, f"Failed to get Jira issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Jira API")
except Exception as e:
logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
return 0, f"Error fetching Jira issues: {e!s}"
await jira_client.close()
return 0, 0, f"Error fetching Jira issues: {e!s}"
# =======================================================================
# PHASE 1: Analyze all issues, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
documents_indexed = 0
if not issues:
logger.info("No Jira issues found for the specified date range")
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
await jira_client.close()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=f"{issue.get('key', '')}: {issue.get('id', '')}",
document_type=DocumentType.JIRA_CONNECTOR,
unique_id=issue.get("key", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"issue_id": issue.get("key", ""),
"connector_id": connector_id,
"connector_type": "Jira",
},
)
for issue in issues
if issue.get("key") and issue.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues:
try:
issue_id = issue.get("key")
@ -207,10 +232,7 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Format the issue for better readability
formatted_issue = jira_client.format_issue(issue)
# Convert to markdown
issue_content = jira_client.format_issue_to_markdown(formatted_issue)
if not issue_content:
@ -220,53 +242,19 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Generate unique identifier hash for this Jira issue
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
doc = _build_connector_doc(
issue,
formatted_issue,
issue_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(formatted_issue.get("comments", []))
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
session, compute_content_hash(doc)
)
if duplicate_by_content:
@ -279,160 +267,36 @@ async def index_jira_issues(
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.JIRA_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": formatted_issue.get("status", "Unknown"),
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata = {
"issue_key": item["issue_identifier"],
"issue_title": item["issue_title"],
"status": item["formatted_issue"].get("status", "Unknown"),
"priority": item["formatted_issue"].get("priority", "Unknown"),
"comment_count": item["comment_count"],
"document_type": "Jira Issue",
"connector_type": "Jira",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata
)
else:
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full issue content with comments
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["formatted_issue"].get("status", "Unknown"),
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Jira issues processed so far"
)
await session.commit()
connector_docs.append(doc)
except Exception as e:
logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue # Skip this issue and continue with others
documents_skipped += 1
continue
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s: AsyncSession):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} Jira issues processed")
try:
await session.commit()
logger.info("Successfully committed all JIRA document changes to database")
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
@ -447,7 +311,6 @@ async def index_jira_issues(
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
@ -455,7 +318,6 @@ async def index_jira_issues(
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed JIRA indexing for connector {connector_id}",
@ -466,17 +328,13 @@ async def index_jira_issues(
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"JIRA indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
# Clean up the connector
await jira_client.close()
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -487,11 +345,10 @@ async def index_jira_issues(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals():
with contextlib.suppress(Exception):
await jira_client.close()
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -501,8 +358,7 @@ async def index_jira_issues(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals():
with contextlib.suppress(Exception):
await jira_client.close()
return 0, f"Failed to index JIRA issues: {e!s}"
return 0, 0, f"Failed to index JIRA issues: {e!s}"

View file

@ -1,48 +1,87 @@
"""
Linear connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending processing ready/failed
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding with bounded parallel indexing.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Linear issue dict to a ConnectorDocument."""
issue_id = issue.get("id", "")
issue_identifier = issue.get("identifier", "")
issue_title = issue.get("title", "")
state = formatted_issue.get("state", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Linear Issue",
"connector_type": "Linear",
}
fallback_summary = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.LINEAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_linear_issues(
session: AsyncSession,
connector_id: int,
@ -52,26 +91,15 @@ async def index_linear_issues(
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
) -> tuple[int, int, str | None]:
"""
Index Linear issues and comments.
Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
Tuple of (indexed_count, skipped_count, warning_or_error_message)
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="linear_issues_indexing",
source="connector_indexing_task",
@ -85,7 +113,7 @@ async def index_linear_issues(
)
try:
# Get the connector
# ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Retrieving Linear connector {connector_id} from database",
@ -104,11 +132,11 @@ async def index_linear_issues(
{"error_type": "ConnectorNotFound"},
)
return (
0,
0,
f"Connector with ID {connector_id} not found or is not a Linear connector",
)
# Check if access_token exists (support both new OAuth format and old API key format)
if not connector.config.get("access_token") and not connector.config.get(
"LINEAR_API_KEY"
):
@ -118,26 +146,22 @@ async def index_linear_issues(
"Missing Linear access token",
{"error_type": "MissingToken"},
)
return 0, "Linear access token not found in connector config"
return 0, 0, "Linear access token not found in connector config"
# Initialize Linear client with internal refresh capability
# ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Linear client for connector {connector_id}",
{"stage": "client_initialization"},
)
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
linear_client = LinearConnector(session=session, connector_id=connector_id)
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
@ -154,37 +178,32 @@ async def index_linear_issues(
},
)
# Get issues within date range
# ── Fetch issues ──────────────────────────────────────────────
try:
issues, error = await linear_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True
start_date=start_date_str,
end_date=end_date_str,
include_comments=True,
)
if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error:
logger.info(f"No Linear issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None
return 0, 0, None
else:
logger.error(f"Failed to get Linear issues: {error}")
return 0, f"Failed to get Linear issues: {error}"
return 0, 0, f"Failed to get Linear issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Linear API")
except Exception as e:
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
return 0, f"Failed to get Linear issues: {e!s}"
return 0, 0, f"Failed to get Linear issues: {e!s}"
if not issues:
logger.info("No Linear issues found for the specified date range")
@ -193,19 +212,34 @@ async def index_linear_issues(
session, connector, update_last_indexed
)
await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None # Return None instead of error message when no issues found
return 0, 0, None
# Track the number of documents indexed
documents_indexed = 0
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=f"{issue.get('identifier', '')}: {issue.get('title', '')}",
document_type=DocumentType.LINEAR_CONNECTOR,
unique_id=issue.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"issue_id": issue.get("id", ""),
"issue_identifier": issue.get("identifier", ""),
"connector_id": connector_id,
"connector_type": "Linear",
},
)
for issue in issues
if issue.get("id") and issue.get("title")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0 # Track issues that failed processing
skipped_issues = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
duplicate_content_count = 0
await task_logger.log_task_progress(
log_entry,
@ -213,13 +247,6 @@ async def index_linear_issues(
{"stage": "process_issues", "total_issues": len(issues)},
)
# =======================================================================
# PHASE 1: Analyze all issues, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues:
try:
issue_id = issue.get("id", "")
@ -230,243 +257,70 @@ async def index_linear_issues(
logger.warning(
f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}"
)
skipped_issues.append(
f"{issue_identifier or 'Unknown'} (missing data)"
)
documents_skipped += 1
continue
# Format the issue first to get well-structured data
formatted_issue = linear_client.format_issue(issue)
# Convert issue to markdown format
issue_content = linear_client.format_issue_to_markdown(formatted_issue)
if not issue_content:
logger.warning(
f"Skipping issue with no content: {issue_identifier} - {issue_title}"
)
skipped_issues.append(f"{issue_identifier} (no content)")
documents_skipped += 1
continue
# Generate unique identifier hash for this Linear issue
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id
)
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
state = formatted_issue.get("state", "Unknown")
description = formatted_issue.get("description", "")
comment_count = len(formatted_issue.get("comments", []))
priority = formatted_issue.get("priority", "Unknown")
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
logger.info(
f"Document for Linear issue {issue_identifier} unchanged. Skipping."
)
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Linear issue {issue_identifier} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.LINEAR_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
doc = _build_connector_doc(
issue,
formatted_issue,
issue_content,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
except Exception as e:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"issue_id": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"priority": item["priority"],
"comment_count": item["comment_count"],
"document_type": "Linear Issue",
"connector_type": "Linear",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata_for_summary
with session.no_autoflush:
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
else:
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
if duplicate:
logger.info(
f"Committing batch: {documents_indexed} Linear issues processed so far"
f"Linear issue {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
await session.commit()
duplicate_content_count += 1
documents_skipped += 1
continue
connector_docs.append(doc)
except Exception as e:
logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}",
f"Error building ConnectorDocument for issue: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
skipped_issues.append(
f"{item.get('issue_identifier', 'Unknown')} (processing error)"
)
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ────────────
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed")
try:
await session.commit()
@ -474,27 +328,25 @@ async def index_linear_issues(
"Successfully committed all Linear document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same issue was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
else:
raise
# Build warning message if there were issues
warning_parts = []
warning_parts: list[str] = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Linear indexing for connector {connector_id}",
@ -503,7 +355,7 @@ async def index_linear_issues(
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"skipped_issues_count": len(skipped_issues),
"duplicate_content_count": duplicate_content_count,
},
)
@ -511,7 +363,7 @@ async def index_linear_issues(
f"Linear indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed"
)
return documents_indexed, warning_message
return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -522,7 +374,7 @@ async def index_linear_issues(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -532,4 +384,4 @@ async def index_linear_issues(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
return 0, f"Failed to index Linear issues: {e!s}"
return 0, 0, f"Failed to index Linear issues: {e!s}"

View file

@ -1,12 +1,10 @@
"""
Notion connector indexer.
Implements real-time document status updates using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (pending processing ready/failed)
Uses the shared IndexingPipelineService for document deduplication,
summarization, chunking, and embedding with bounded parallel indexing.
"""
import time
from collections.abc import Awaitable, Callable
from datetime import datetime
@ -14,42 +12,67 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.notion_history import NotionHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.notion_utils import process_blocks
from .base import (
build_document_metadata_string,
calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type alias for retry callback
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]]
# Type alias for heartbeat callback
# Signature: async callback(indexed_count) -> None
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Notion page dict to a ConnectorDocument."""
page_id = page.get("page_id", "")
page_title = page.get("title", f"Untitled page ({page_id})")
metadata = {
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
"document_type": "Notion Page",
"connector_type": "Notion",
}
fallback_summary = f"Notion Page: {page_title}\n\n{markdown_content}"
return ConnectorDocument(
title=page_title,
source_markdown=markdown_content,
unique_id=page_id,
document_type=DocumentType.NOTION_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_notion_pages(
session: AsyncSession,
connector_id: int,
@ -60,30 +83,15 @@ async def index_notion_pages(
update_last_indexed: bool = True,
on_retry_callback: RetryCallbackType | None = None,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
) -> tuple[int, int, str | None]:
"""
Index Notion pages from all accessible pages.
Args:
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_retry_callback: Optional callback for retry progress notifications.
Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds)
retry_reason is one of: 'rate_limit', 'server_error', 'timeout'
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Called periodically with (indexed_count) to prevent task appearing stuck.
Returns:
Tuple containing (number of documents indexed, error message or None)
Tuple of (indexed_count, skipped_count, warning_or_error_message)
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="notion_pages_indexing",
source="connector_indexing_task",
@ -97,7 +105,7 @@ async def index_notion_pages(
)
try:
# Get the connector
# ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Retrieving Notion connector {connector_id} from database",
@ -116,11 +124,11 @@ async def index_notion_pages(
{"error_type": "ConnectorNotFound"},
)
return (
0,
0,
f"Connector with ID {connector_id} not found or is not a Notion connector",
)
# Check if access_token exists (support both new OAuth format and old integration token format)
if not connector.config.get("access_token") and not connector.config.get(
"NOTION_INTEGRATION_TOKEN"
):
@ -130,9 +138,9 @@ async def index_notion_pages(
"Missing Notion access token",
{"error_type": "MissingToken"},
)
return 0, "Notion access token not found in connector config"
return 0, 0, "Notion access token not found in connector config"
# Initialize Notion client with internal refresh capability
# ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress(
log_entry,
f"Initializing Notion client for connector {connector_id}",
@ -141,18 +149,15 @@ async def index_notion_pages(
logger.info(f"Initializing Notion client for connector {connector_id}")
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
start_date = None
if end_date == "undefined" or end_date == "":
end_date = None
# Calculate date range using the shared utility function
start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365
)
# Convert YYYY-MM-DD to ISO format for Notion API
start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
@ -160,13 +165,10 @@ async def index_notion_pages(
"%Y-%m-%dT%H:%M:%SZ"
)
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
notion_client = NotionHistoryConnector(
session=session, connector_id=connector_id
)
# Set retry callback if provided (for user notifications during rate limits)
if on_retry_callback:
notion_client.set_retry_callback(on_retry_callback)
@ -182,21 +184,19 @@ async def index_notion_pages(
},
)
# Get all pages
# ── Fetch pages ───────────────────────────────────────────────
try:
pages = await notion_client.get_all_pages(
start_date=start_date_iso, end_date=end_date_iso
)
logger.info(f"Found {len(pages)} Notion pages")
# Get count of pages that had unsupported content skipped
pages_with_skipped_content = notion_client.get_skipped_content_count()
if pages_with_skipped_content > 0:
logger.info(
f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)"
)
# Check if using legacy integration token and log warning
if notion_client.is_using_legacy_token():
logger.warning(
f"Connector {connector_id} is using legacy integration token. "
@ -204,8 +204,6 @@ async def index_notion_pages(
)
except Exception as e:
error_str = str(e)
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
# These are known Notion API limitations and should be logged as warnings, not errors
unsupported_block_errors = [
"transcription is not supported",
"ai_block is not supported",
@ -216,7 +214,6 @@ async def index_notion_pages(
)
if is_unsupported_block_error:
# Log as warning since this is a known Notion API limitation
logger.warning(
f"Notion API limitation for connector {connector_id}: {error_str}. "
"This is a known issue with Notion AI blocks (transcription, ai_block) "
@ -229,7 +226,6 @@ async def index_notion_pages(
{"error_type": "UnsupportedBlockType", "is_known_limitation": True},
)
else:
# Log as error for other failures
logger.error(
f"Error fetching Notion pages for connector {connector_id}: {error_str}",
exc_info=True,
@ -242,7 +238,7 @@ async def index_notion_pages(
)
await notion_client.close()
return 0, f"Failed to get Notion pages: {e!s}"
return 0, 0, f"Failed to get Notion pages: {e!s}"
if not pages:
await task_logger.log_task_success(
@ -252,21 +248,36 @@ async def index_notion_pages(
{"pages_found": 0},
)
logger.info("No Notion pages found to index")
# CRITICAL: Update timestamp even when no pages found so Zero syncs
await update_connector_last_indexed(session, connector, update_last_indexed)
await session.commit()
await notion_client.close()
return 0, None # Success with 0 pages, not an error
return 0, 0, None
# Track the number of documents indexed
documents_indexed = 0
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=page.get("title", f"Untitled page ({page.get('page_id', '')})"),
document_type=DocumentType.NOTION_CONNECTOR,
unique_id=page.get("page_id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"page_id": page.get("page_id", ""),
"connector_id": connector_id,
"connector_type": "Notion",
},
)
for page in pages
if page.get("page_id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
skipped_pages = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
await task_logger.log_task_progress(
log_entry,
@ -274,13 +285,6 @@ async def index_notion_pages(
{"stage": "process_pages", "total_pages": len(pages)},
)
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
for page in pages:
try:
page_id = page.get("page_id")
@ -293,225 +297,67 @@ async def index_notion_pages(
if not page_content:
logger.info(f"No content found in page {page_title}. Skipping.")
skipped_pages.append(f"{page_title} (no content)")
documents_skipped += 1
continue
# Convert page content to markdown format
markdown_content = f"# Notion Page: {page_title}\n\n"
markdown_content += process_blocks(page_content)
# Format document metadata
metadata_sections = [
("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]),
(
"CONTENT",
[
"FORMAT: markdown",
"TEXT_START",
markdown_content,
"TEXT_END",
],
),
]
# Build the document string
combined_document_string = build_document_metadata_string(
metadata_sections
)
# Generate unique identifier hash for this Notion page
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTION_CONNECTOR, page_id, search_space_id
)
# Generate content hash
content_hash = generate_content_hash(
combined_document_string, search_space_id
)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
pages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
}
)
if not markdown_content.strip():
logger.warning(f"Skipping page with empty markdown: {page_title}")
documents_skipped += 1
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
doc = _build_connector_doc(
page,
markdown_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
)
if duplicate_by_content:
with session.no_autoflush:
duplicate = await check_duplicate_document_by_hash(
session, compute_content_hash(doc)
)
if duplicate:
logger.info(
f"Notion page {page_title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
f"Notion page {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.NOTION_CONNECTOR,
document_metadata={
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
}
)
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
for item in pages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
logger.error(
f"Error building ConnectorDocument for page: {e!s}",
exc_info=True,
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"document_type": "Notion Page",
"connector_type": "Notion",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Notion pages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Notion page: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
skipped_pages.append(f"{item['page_title']} (processing error)")
documents_failed += 1
documents_skipped += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# ── Pipeline: migrate legacy docs + parallel index ────────────
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
total_processed = documents_indexed
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} documents processed")
try:
await session.commit()
@ -519,59 +365,53 @@ async def index_notion_pages(
"Successfully committed all Notion document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same page was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Get final count of pages with skipped Notion AI content
# ── Build warning / notification message ──────────────────────
pages_with_skipped_ai_content = notion_client.get_skipped_content_count()
# Build warning message if there were issues
warning_parts = []
warning_parts: list[str] = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Prepare result message with user-friendly notification about skipped content
result_message = None
if skipped_pages:
result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
else:
result_message = f"Processed {total_processed} pages."
# Add user-friendly message about skipped Notion AI content
notification_parts: list[str] = []
if pages_with_skipped_ai_content > 0:
result_message += (
" Audio transcriptions and AI summaries from Notion aren't accessible "
"via their API - all other content was saved."
notification_parts.append(
"Some Notion AI content couldn't be synced (API limitation)"
)
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
if warning_parts:
notification_parts.append(", ".join(warning_parts))
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
)
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully completed Notion indexing for connector {connector_id}",
{
"pages_processed": total_processed,
"pages_processed": documents_indexed,
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
"skipped_pages_count": len(skipped_pages),
"pages_with_skipped_ai_content": pages_with_skipped_ai_content,
"result_message": result_message,
},
)
@ -581,35 +421,9 @@ async def index_notion_pages(
f"({duplicate_content_count} duplicate content)"
)
# Clean up the async client
await notion_client.close()
# Build user-friendly notification messages
# This will be shown in the notification to inform users
notification_parts = []
if pages_with_skipped_ai_content > 0:
notification_parts.append(
"Some Notion AI content couldn't be synced (API limitation)"
)
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
# Include warning message if there were issues
if warning_message:
notification_parts.append(warning_message)
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
)
return (
total_processed,
user_notification_message,
)
return documents_indexed, documents_skipped, user_notification_message
except SQLAlchemyError as db_error:
await session.rollback()
@ -622,10 +436,9 @@ async def index_notion_pages(
logger.error(
f"Database error during Notion indexing: {db_error!s}", exc_info=True
)
# Clean up the async client in case of error
if "notion_client" in locals():
await notion_client.close()
return 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -635,7 +448,6 @@ async def index_notion_pages(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
# Clean up the async client in case of error
if "notion_client" in locals():
await notion_client.close()
return 0, f"Failed to index Notion pages: {e!s}"
return 0, 0, f"Failed to index Notion pages: {e!s}"

View file

@ -1,5 +1,6 @@
import hashlib
import logging
import threading
import warnings
import numpy as np
@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
logger = logging.getLogger(__name__)
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
# access from multiple threads causes "RuntimeError: Already borrowed".
# This reentrant lock serialises tokenizer + embedding model access so that
# asyncio.to_thread calls from index_batch_parallel don't collide.
_embedding_lock = threading.RLock()
def _get_embedding_max_tokens() -> int:
"""Get the max token limit for the configured embedding model.
@ -36,23 +43,25 @@ def truncate_for_embedding(text: str) -> str:
if len(text) // 3 <= max_tokens:
return text
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
return text
with _embedding_lock:
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
return text
warnings.warn(
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
stacklevel=2,
)
return tokenizer.decode(tokens[:max_tokens])
warnings.warn(
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
stacklevel=2,
)
return tokenizer.decode(tokens[:max_tokens])
def embed_text(text: str) -> np.ndarray:
"""Truncate text to fit and embed it. Drop-in replacement for
``config.embedding_model_instance.embed(text)`` that never exceeds the
model's context window."""
return config.embedding_model_instance.embed(truncate_for_embedding(text))
with _embedding_lock:
return config.embedding_model_instance.embed(truncate_for_embedding(text))
def embed_texts(texts: list[str]) -> list[np.ndarray]:
@ -66,10 +75,11 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
"""
if not texts:
return []
truncated = [truncate_for_embedding(t) for t in texts]
if config.is_local_embedding_model:
return [config.embedding_model_instance.embed(t) for t in truncated]
return config.embedding_model_instance.embed_batch(truncated)
with _embedding_lock:
truncated = [truncate_for_embedding(t) for t in texts]
if config.is_local_embedding_model:
return [config.embedding_model_instance.embed(t) for t in truncated]
return config.embedding_model_instance.embed_batch(truncated)
def get_model_context_window(model_name: str) -> int:

View file

@ -46,7 +46,6 @@ dependencies = [
"redis>=5.2.1",
"firecrawl-py>=4.9.0",
"boto3>=1.35.0",
"langchain-community>=0.3.31",
"litellm>=1.80.10",
"langchain-litellm>=0.3.5",
"fake-useragent>=2.2.0",
@ -60,19 +59,21 @@ dependencies = [
"sse-starlette>=3.1.1,<3.1.2",
"gitingest>=0.3.1",
"composio>=0.10.9",
"langchain>=1.2.6",
"langgraph>=1.0.5",
"unstructured[all-docs]>=0.18.31",
"unstructured-client>=0.42.3",
"langchain-unstructured>=1.0.1",
"slowapi>=0.1.9",
"pypandoc_binary>=1.16.2",
"typst>=0.14.0",
"deepagents>=0.4.3",
"daytona>=0.146.0",
"langchain-daytona>=0.0.2",
"pypandoc>=1.16.2",
"notion-markdown>=0.7.0",
"fractional-indexing>=0.1.3",
"langchain>=1.2.13",
"langgraph>=1.1.3",
"langchain-community>=0.4.1",
"deepagents>=0.4.12",
]
[dependency-groups]

View file

@ -0,0 +1,119 @@
"""Integration tests: Calendar indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _cal_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"Event {unique_id}",
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Calendar: Event {unique_id}",
metadata={
"event_id": unique_id,
"start_time": "2025-01-15T10:00:00",
"end_time": "2025-01-15T11:00:00",
"document_type": "Google Calendar Event",
},
)
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Calendar ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _cal_doc(
unique_id="evt-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Calendar doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
evt_id = "evt-legacy-cal"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
legacy_doc = Document(
title="Old Calendar Event",
document_type=DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old event",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _cal_doc(
unique_id=evt_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
assert row.unique_identifier_hash == native_hash

View file

@ -0,0 +1,185 @@
"""Integration tests: Drive indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _drive_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"File {unique_id}.pdf",
source_markdown=f"## Document Content\n\nText from file {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_DRIVE_FILE,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"File: {unique_id}.pdf",
metadata={
"google_drive_file_id": unique_id,
"google_drive_file_name": f"{unique_id}.pdf",
"document_type": "Google Drive File",
},
)
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Drive ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _drive_doc(
unique_id="file-abc",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Drive doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
file_id = "file-legacy-drive"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, space_id
)
legacy_doc = Document(
title="Old Drive File",
document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
content="old file summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old file content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _drive_doc(
unique_id=file_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
assert row.unique_identifier_hash == native_hash
async def test_should_skip_file_skips_failed_document(
db_session,
db_search_space,
db_user,
):
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
import importlib
import sys
import types
pkg = "app.tasks.connector_indexers"
stub = pkg not in sys.modules
if stub:
mod = types.ModuleType(pkg)
mod.__path__ = ["app/tasks/connector_indexers"]
mod.__package__ = pkg
sys.modules[pkg] = mod
try:
gdm = importlib.import_module(
"app.tasks.connector_indexers.google_drive_indexer"
)
_should_skip_file = gdm._should_skip_file
finally:
if stub:
sys.modules.pop(pkg, None)
space_id = db_search_space.id
file_id = "file-failed-drive"
md5 = "abc123deadbeef"
doc_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
failed_doc = Document(
title="Failed File.pdf",
document_type=DocumentType.GOOGLE_DRIVE_FILE,
content="LLM rate limit exceeded",
content_hash=f"ch-{doc_hash[:12]}",
unique_identifier_hash=doc_hash,
source_markdown="## Real content",
search_space_id=space_id,
created_by_id=str(db_user.id),
embedding=[0.1] * _EMBEDDING_DIM,
status=DocumentStatus.failed("LLM rate limit exceeded"),
document_metadata={
"google_drive_file_id": file_id,
"google_drive_file_name": "Failed File.pdf",
"md5_checksum": md5,
},
)
db_session.add(failed_doc)
await db_session.flush()
incoming_file = {
"id": file_id,
"name": "Failed File.pdf",
"mimeType": "application/pdf",
"md5Checksum": md5,
}
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
assert should_skip, "FAILED documents must be skipped during automatic sync"
assert "failed" in msg.lower()

View file

@ -0,0 +1,121 @@
"""Integration tests: Gmail indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _gmail_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
return ConnectorDocument(
title=f"Subject for {unique_id}",
source_markdown=f"## Email\n\nBody of {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Gmail: Subject for {unique_id}",
metadata={
"message_id": unique_id,
"from": "sender@example.com",
"document_type": "Gmail Message",
},
)
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Gmail ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _gmail_doc(
unique_id="msg-pipeline-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.source_markdown == doc.source_markdown
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_legacy_doc_migrated_then_reused(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Gmail doc is migrated then reused by the pipeline."""
space_id = db_search_space.id
user_id = str(db_user.id)
msg_id = "msg-legacy-gmail"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, msg_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _gmail_doc(
unique_id=msg_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
prepared = await service.prepare_for_indexing([connector_doc])
assert len(prepared) == 1
assert prepared[0].id == original_id
assert prepared[0].document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, msg_id, space_id
)
assert prepared[0].unique_identifier_hash == native_hash

View file

@ -0,0 +1,59 @@
"""Integration tests for IndexingPipelineService.index_batch()."""
import pytest
from sqlalchemy import select
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_creates_ready_documents(
db_session, db_search_space, make_connector_document, mocker
):
"""index_batch prepares and indexes a batch, resulting in READY documents."""
space_id = db_search_space.id
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-1",
search_space_id=space_id,
source_markdown="## Email 1\n\nBody",
),
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-2",
search_space_id=space_id,
source_markdown="## Email 2\n\nDifferent body",
),
]
service = IndexingPipelineService(session=db_session)
results = await service.index_batch(docs, llm=mocker.Mock())
assert len(results) == 2
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
rows = result.scalars().all()
assert len(rows) == 2
for row in rows:
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.content is not None
assert row.embedding is not None
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_empty_returns_empty(db_session, mocker):
"""index_batch with empty input returns an empty list."""
service = IndexingPipelineService(session=db_session)
results = await service.index_batch([], llm=mocker.Mock())
assert results == []

View file

@ -0,0 +1,92 @@
"""Integration tests for IndexingPipelineService.migrate_legacy_docs()."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
async def test_legacy_composio_gmail_doc_migrated_in_db(
db_session, db_search_space, db_user, make_connector_document
):
"""A Composio Gmail doc in the DB gets its hash and type updated to native."""
space_id = db_search_space.id
user_id = str(db_user.id)
unique_id = "msg-legacy-123"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, unique_id, space_id
)
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, unique_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="legacy content",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
doc_id = legacy_doc.id
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=unique_id,
search_space_id=space_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == doc_id))
reloaded = result.scalars().first()
assert reloaded.unique_identifier_hash == native_hash
assert reloaded.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
async def test_no_legacy_doc_is_noop(
db_session, db_search_space, make_connector_document
):
"""When no legacy document exists, migrate_legacy_docs does nothing."""
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
unique_id="evt-no-legacy",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert result.scalars().all() == []
async def test_non_google_type_is_skipped(
db_session, db_search_space, make_connector_document
):
"""migrate_legacy_docs skips ConnectorDocuments that are not Google types."""
connector_doc = make_connector_document(
document_type=DocumentType.CLICKUP_CONNECTOR,
unique_id="task-1",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])

View file

@ -0,0 +1,106 @@
"""Shared fixtures for retriever integration tests."""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Chunk, Document, DocumentType, SearchSpace, User
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
def _make_document(
*,
title: str,
document_type: DocumentType,
content: str,
search_space_id: int,
created_by_id: str,
) -> Document:
uid = uuid.uuid4().hex[:12]
return Document(
title=title,
document_type=document_type,
content=content,
content_hash=f"content-{uid}",
unique_identifier_hash=f"uid-{uid}",
source_markdown=content,
search_space_id=search_space_id,
created_by_id=created_by_id,
embedding=DUMMY_EMBEDDING,
updated_at=datetime.now(UTC),
status={"state": "ready"},
)
def _make_chunk(*, content: str, document_id: int) -> Chunk:
return Chunk(
content=content,
document_id=document_id,
embedding=DUMMY_EMBEDDING,
)
@pytest_asyncio.fixture
async def seed_large_doc(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
"""Insert a document with 35 chunks (more than _MAX_FETCH_CHUNKS_PER_DOC=20).
Also inserts a small 3-chunk document for diversity testing.
Returns a dict with ``large_doc``, ``small_doc``, ``search_space``, ``user``,
and ``large_chunk_ids`` (all 35 chunk IDs).
"""
user_id = str(db_user.id)
space_id = db_search_space.id
large_doc = _make_document(
title="Large PDF Document",
document_type=DocumentType.FILE,
content="large document about quarterly performance reviews and budgets",
search_space_id=space_id,
created_by_id=user_id,
)
small_doc = _make_document(
title="Small Note",
document_type=DocumentType.NOTE,
content="quarterly performance review summary note",
search_space_id=space_id,
created_by_id=user_id,
)
db_session.add_all([large_doc, small_doc])
await db_session.flush()
large_chunks = []
for i in range(35):
chunk = _make_chunk(
content=f"chunk {i} about quarterly performance review section {i}",
document_id=large_doc.id,
)
large_chunks.append(chunk)
small_chunks = [
_make_chunk(
content="quarterly performance review summary note content",
document_id=small_doc.id,
),
]
db_session.add_all(large_chunks + small_chunks)
await db_session.flush()
return {
"large_doc": large_doc,
"small_doc": small_doc,
"large_chunk_ids": [c.id for c in large_chunks],
"small_chunk_ids": [c.id for c in small_chunks],
"search_space": db_search_space,
"user": db_user,
}

View file

@ -0,0 +1,116 @@
"""Integration tests for optimized ChucksHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit, column pruning,
and doc metadata caching from RRF results.
"""
import pytest
from app.retriever.chunks_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
ChucksHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated_from_rrf(db_session, seed_large_doc):
"""Document metadata (title, type, etc.) should be present even without joinedload."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_matched_chunk_ids_tracked(db_session, seed_large_doc):
"""matched_chunk_ids should contain the chunk IDs that appeared in the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
matched = result.get("matched_chunk_ids", [])
chunk_ids_in_result = {c["chunk_id"] for c in result["chunks"]}
for mid in matched:
assert mid in chunk_ids_in_result, (
f"matched_chunk_id {mid} not found in chunks"
)
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID (original order)."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
async def test_score_is_positive_float(db_session, seed_large_doc):
"""Each result should have a positive float score from RRF."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
assert isinstance(result["score"], float)
assert result["score"] > 0

View file

@ -0,0 +1,76 @@
"""Integration tests for optimized DocumentHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit and column pruning.
"""
import pytest
from app.retriever.documents_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
DocumentHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated(db_session, seed_large_doc):
"""Document metadata should be present from the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"

View file

@ -0,0 +1,34 @@
"""Pre-register the connector_indexers package to bypass a circular import
in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers).
This lets tests import individual indexer modules (e.g.
``google_drive_indexer``) without triggering the full package init.
"""
import sys
import types
from pathlib import Path
_BACKEND = Path(__file__).resolve().parents[3]
def _stub_package(dotted: str, fs_dir: Path) -> None:
if dotted not in sys.modules:
mod = types.ModuleType(dotted)
mod.__path__ = [str(fs_dir)]
mod.__package__ = dotted
sys.modules[dotted] = mod
parts = dotted.split(".")
if len(parts) > 1:
parent_dotted = ".".join(parts[:-1])
parent = sys.modules.get(parent_dotted)
if parent is not None:
setattr(parent, parts[-1], sys.modules[dotted])
_stub_package("app.tasks", _BACKEND / "app" / "tasks")
_stub_package(
"app.tasks.connector_indexers",
_BACKEND / "app" / "tasks" / "connector_indexers",
)

View file

@ -0,0 +1,385 @@
"""Tests for Confluence indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.confluence_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.confluence_indexer import (
_build_connector_doc,
index_confluence_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(
page_id: str = "p1",
title: str = "Home",
space_id: str = "S1",
body: str = "<p>Hello</p>",
comments=None,
):
return {
"id": page_id,
"title": title,
"spaceId": space_id,
"body": {"storage": {"value": body}},
"comments": comments or [],
}
def _to_markdown(page: dict) -> str:
page_title = page.get("title", "")
page_content = page.get("body", {}).get("storage", {}).get("value", "")
comments = page.get("comments", [])
comments_content = ""
if comments:
comments_content = "\n\n## Comments\n\n"
for comment in comments:
comment_body = comment.get("body", {}).get("storage", {}).get("value", "")
comment_author = comment.get("version", {}).get("authorId", "Unknown")
comment_date = comment.get("version", {}).get("createdAt", "")
comments_content += (
f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
)
return f"# {page_title}\n\n{page_content}{comments_content}"
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
page = _make_page(
page_id="abc-123",
title="Engineering Handbook",
space_id="ENG",
comments=[{"id": "c1"}],
)
markdown = _to_markdown(page)
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "Engineering Handbook"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.CONFLUENCE_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["page_title"] == "Engineering Handbook"
assert doc.metadata["space_id"] == "ENG"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Confluence Page"
assert doc.metadata["connector_type"] == "Confluence"
assert doc.fallback_summary is not None
assert "Engineering Handbook" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_page(),
_to_markdown(_make_page()),
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_confluence_client(pages=None, error=None):
client = MagicMock()
client.get_pages_by_date_range = AsyncMock(
return_value=(pages if pages is not None else [], error),
)
client.close = AsyncMock()
return client
@pytest.fixture
def confluence_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
confluence_client = _mock_confluence_client(pages=[_make_page()])
monkeypatch.setattr(
_mod,
"ConfluenceHistoryConnector",
MagicMock(return_value=confluence_client),
)
monkeypatch.setattr(
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"confluence_client": confluence_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_confluence_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(confluence_mocks):
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_called_once()
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.CONFLUENCE_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(confluence_mocks):
await _run_index(confluence_mocks)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(confluence_mocks):
await _run_index(confluence_mocks)
confluence_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (missing id/title/content)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="", title="Missing id"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_missing_title_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="p2", title=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_no_content_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid", body="<p>ok</p>"),
_make_page(page_id="p2", title="Empty", body=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(confluence_mocks, monkeypatch):
pages = [
_make_page(page_id="p1", title="One"),
_make_page(page_id="p2", title="Two"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(confluence_mocks):
heartbeat_cb = AsyncMock()
await _run_index(confluence_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty pages and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
None,
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_not_called()
async def test_no_pages_error_message_returns_success_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"No pages found in date range",
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(confluence_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Confluence pages" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(confluence_mocks):
confluence_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(confluence_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,740 @@
"""Tests for parallel download + indexing in the Google Drive indexer."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.tasks.connector_indexers.google_drive_indexer import (
_download_files_parallel,
_index_full_scan,
_index_selected_files,
_index_with_delta_sync,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
return {"id": file_id, "name": name, "mimeType": mime}
def _mock_extract_ok(file_id: str, file_name: str):
"""Return a successful (markdown, metadata, None) tuple."""
return (
f"# Content of {file_name}",
{"google_drive_file_id": file_id, "google_drive_file_name": file_name},
None,
)
@pytest.fixture
def mock_drive_client():
return MagicMock()
@pytest.fixture
def patch_extract(monkeypatch):
"""Provide a helper to set the download_and_extract_content mock."""
def _patch(side_effect=None, return_value=None):
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
mock,
)
return mock
return _patch
async def test_single_file_returns_one_connector_document(
mock_drive_client,
patch_extract,
):
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
docs, failed = await _download_files_parallel(
mock_drive_client,
[_make_file_dict("f1", "test.txt")],
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 0
assert docs[0].title == "test.txt"
assert docs[0].unique_id == "f1"
async def test_multiple_files_all_produce_documents(
mock_drive_client,
patch_extract,
):
"""All files are downloaded and converted to ConnectorDocuments."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 3
assert failed == 0
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
async def test_one_download_exception_does_not_block_others(
mock_drive_client,
patch_extract,
):
"""A RuntimeError in one download still lets the other files succeed."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "file0.txt"),
RuntimeError("network timeout"),
_mock_extract_ok("f2", "file2.txt"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 2
assert failed == 1
assert {d.unique_id for d in docs} == {"f0", "f2"}
async def test_etl_error_counts_as_download_failure(
mock_drive_client,
patch_extract,
):
"""download_and_extract_content returning an error is counted as failed."""
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "good.txt"),
(None, {}, "ETL failed"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 1
async def test_concurrency_bounded_by_semaphore(
mock_drive_client,
monkeypatch,
):
"""Peak concurrent downloads never exceeds max_concurrency."""
lock = asyncio.Lock()
active = 0
peak = 0
async def _slow_extract(client, file):
nonlocal active, peak
async with lock:
active += 1
peak = max(peak, active)
await asyncio.sleep(0.05)
async with lock:
active -= 1
fid = file["id"]
return _mock_extract_ok(fid, file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
max_concurrency=2,
)
assert len(docs) == 6
assert failed == 0
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
async def test_heartbeat_fires_during_parallel_downloads(
mock_drive_client,
monkeypatch,
):
"""on_heartbeat is called at least once when downloads take time."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
async def _slow_extract(client, file):
await asyncio.sleep(0.05)
return _mock_extract_ok(file["id"], file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
heartbeat_calls: list[int] = []
async def _on_heartbeat(count: int):
heartbeat_calls.append(count)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
on_heartbeat=_on_heartbeat,
)
assert len(docs) == 3
assert failed == 0
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
# ---------------------------------------------------------------------------
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
# ---------------------------------------------------------------------------
def _folder_dict(file_id: str, name: str) -> dict:
return {
"id": file_id,
"name": name,
"mimeType": "application/vnd.google-apps.folder",
}
@pytest.fixture
def full_scan_mocks(mock_drive_client, monkeypatch):
"""Wire up all mocks needed to call _index_full_scan in isolation."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
mock_connector = MagicMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
mock_log_entry = MagicMock()
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_mock = AsyncMock(return_value=([], 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 0, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"connector": mock_connector,
"task_logger": mock_task_logger,
"log_entry": mock_log_entry,
"skip_results": skip_results,
"download_mock": download_mock,
"batch_mock": batch_mock,
"pipeline_mock": pipeline_mock,
}
async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False):
return await _index_full_scan(
mocks["drive_client"],
mocks["session"],
mocks["connector"],
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"My Folder",
mocks["task_logger"],
mocks["log_entry"],
max_files,
include_subfolders=include_subfolders,
enable_summary=True,
)
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
"""Full scan collects files serially, downloads and indexes in parallel,
and returns correct (indexed, skipped) with renames counted as indexed."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [
_folder_dict("folder1", "SubFolder"),
_make_file_dict("skip1", "unchanged.txt"),
_make_file_dict("rename1", "renamed.txt"),
_make_file_dict("new1", "new1.txt"),
_make_file_dict("new2", "new2.txt"),
]
monkeypatch.setattr(
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (
True,
"File renamed: 'old''renamed.txt'",
)
mock_docs = [MagicMock(), MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
indexed, skipped = await _run_full_scan(full_scan_mocks)
assert indexed == 3 # 1 renamed + 2 from batch
assert skipped == 1 # 1 unchanged
full_scan_mocks["download_mock"].assert_called_once()
call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"new1", "new2"}
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
"""Only max_files non-folder files are processed; the rest are ignored."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
monkeypatch.setattr(
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["download_mock"].return_value = ([], 0)
full_scan_mocks["batch_mock"].return_value = ([], 0, 0)
await _run_full_scan(full_scan_mocks, max_files=3)
download_call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(download_call_files) == 3
async def test_full_scan_uses_max_concurrency_3_for_indexing(
full_scan_mocks,
monkeypatch,
):
"""index_batch_parallel is called with max_concurrency=3."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict("f1", "file1.txt")]
monkeypatch.setattr(
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
mock_docs = [MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 1, 0)
await _run_full_scan(full_scan_mocks)
call_kwargs = full_scan_mocks["batch_mock"].call_args
assert call_kwargs[1].get("max_concurrency") == 3 or (
len(call_kwargs[0]) > 2 and call_kwargs[0][2] == 3
)
# ---------------------------------------------------------------------------
# Slice 7 -- _index_with_delta_sync three-phase pipeline
# ---------------------------------------------------------------------------
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"""Removed/trashed changes call _remove_document; the rest go through
_download_files_parallel and index_batch_parallel."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
changes = [
{"fileId": "del1", "removed": True},
{"fileId": "del2", "file": {"id": "del2", "trashed": True}},
{"fileId": "trash1", "file": {"id": "trash1", "trashed": True}},
{"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")},
{"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")},
]
monkeypatch.setattr(
_mod,
"fetch_all_changes",
AsyncMock(return_value=(changes, "new-token", None)),
)
change_types = {
"del1": "removed",
"del2": "removed",
"trash1": "trashed",
"mod1": "modified",
"mod2": "modified",
}
monkeypatch.setattr(
_mod,
"categorize_change",
lambda change: change_types[change["fileId"]],
)
remove_calls: list[str] = []
async def _fake_remove(session, file_id, search_space_id):
remove_calls.append(file_id)
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod,
"_should_skip_file",
AsyncMock(return_value=(False, None)),
)
mock_docs = [MagicMock(), MagicMock()]
download_mock = AsyncMock(return_value=(mock_docs, 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 2, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
mock_session = AsyncMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped = await _index_with_delta_sync(
MagicMock(),
mock_session,
MagicMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"start-token-abc",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert sorted(remove_calls) == ["del1", "del2", "trash1"]
download_mock.assert_called_once()
downloaded_files = download_mock.call_args[0][1]
assert len(downloaded_files) == 2
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
assert indexed == 2
assert skipped == 0
# ---------------------------------------------------------------------------
# _index_selected_files -- parallel indexing of user-selected files
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
async def _fake_get_file(client, file_id):
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
pipeline_mock = MagicMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"get_file_results": get_file_results,
"skip_results": skip_results,
"download_and_index_mock": download_and_index_mock,
}
async def _run_selected(mocks, file_ids):
return await _index_selected_files(
mocks["drive_client"],
mocks["session"],
file_ids,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
async def test_selected_files_single_file_indexed(selected_files_mocks):
"""Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "report.pdf"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("f1", "report.pdf")],
)
assert indexed == 1
assert skipped == 0
assert errors == []
selected_files_mocks["download_and_index_mock"].assert_called_once()
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
"""get_file_by_id failing for one file collects an error; others still indexed."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "first.txt"),
None,
)
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["f3"] = (
_make_file_dict("f3", "third.txt"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
)
assert indexed == 2
assert skipped == 0
assert len(errors) == 1
assert "mid.txt" in errors[0]
assert "HTTP 404" in errors[0]
async def test_selected_files_skip_rename_counting(selected_files_mocks):
"""Unchanged files are skipped, renames counted as indexed,
and only new files are sent to _download_and_index."""
for fid, fname in [
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
]:
selected_files_mocks["get_file_results"][fid] = (
_make_file_dict(fid, fname),
None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (
True,
"File renamed: 'old' \u2192 'renamed.txt'",
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
mock.assert_called_once()
call_files = (
mock.call_args[1].get("files")
if "files" in (mock.call_args[1] or {})
else mock.call_args[0][2]
)
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
# ---------------------------------------------------------------------------
# asyncio.to_thread verification — prove blocking calls run in parallel
# ---------------------------------------------------------------------------
async def test_client_download_file_runs_in_thread_parallel():
"""Calling download_file concurrently via asyncio.gather should overlap
blocking work on separate threads, proving to_thread is effective.
Strategy: patch _sync_download_file with a blocking time.sleep(0.2).
Launch 3 concurrent calls. Serial would take >=0.6s; parallel < 0.4s.
"""
from app.connectors.google_drive.client import GoogleDriveClient
block_seconds = 0.2
num_calls = 3
def _blocking_download(service, file_id, credentials):
time.sleep(block_seconds)
return b"fake-content", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient,
"_sync_download_file",
staticmethod(_blocking_download),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.download_file(f"file-{i}") for i in range(num_calls))
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"fake-content"
assert error is None
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"downloads are not running in parallel"
)
async def test_client_export_google_file_runs_in_thread_parallel():
"""Same strategy for export_google_file — verify to_thread parallelism."""
from app.connectors.google_drive.client import GoogleDriveClient
block_seconds = 0.2
num_calls = 3
def _blocking_export(service, file_id, mime_type, credentials):
time.sleep(block_seconds)
return b"exported", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient,
"_sync_export_google_file",
staticmethod(_blocking_export),
):
start = time.monotonic()
results = await asyncio.gather(
*(
client.export_google_file(f"file-{i}", "application/pdf")
for i in range(num_calls)
)
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"exported"
assert error is None
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"exports are not running in parallel"
)

View file

@ -0,0 +1,387 @@
"""Tests for Jira indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.jira_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.jira_indexer import (
_build_connector_doc,
index_jira_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
):
return {"key": issue_key, "id": issue_id, "title": title}
def _make_formatted_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
status: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"key": issue_key,
"id": issue_id,
"title": title,
"status": status,
"priority": priority,
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
issue = _make_issue(issue_key="ENG-42", issue_id="4242", title="Fix auth bug")
formatted = _make_formatted_issue(
issue_key="ENG-42",
issue_id="4242",
title="Fix auth bug",
status="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix auth bug\n\nBody"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: 4242"
assert doc.unique_id == "ENG-42"
assert doc.document_type == DocumentType.JIRA_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "ENG-42"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "4242"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Jira Issue"
assert doc.metadata["connector_type"] == "Jira"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_jira_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_key=i.get("key", ""),
issue_id=i.get("id", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent"
)
client.close = AsyncMock()
return client
@pytest.fixture
def jira_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
jira_client = _mock_jira_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod,
"JiraHistoryConnector",
MagicMock(return_value=jira_client),
)
monkeypatch.setattr(
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"jira_client": jira_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_jira_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks):
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_called_once()
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.JIRA_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(jira_mocks):
await _run_index(jira_mocks)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(jira_mocks):
await _run_index(jira_mocks)
jira_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing key/title/content)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_key_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "", "id": "10002", "title": "No key"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "ENG-2", "id": "", "title": "Missing id used as title"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_no_content_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
jira_mocks["jira_client"].format_issue_to_markdown.side_effect = [
"# ENG-1: 10001\n\nContent",
"",
]
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks):
heartbeat_cb = AsyncMock()
await _run_index(jira_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_not_called()
async def test_no_issues_error_message_returns_success_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"No issues found in date range",
)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(jira_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Jira issues" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(jira_mocks):
jira_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(jira_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,374 @@
"""Tests for Linear indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.linear_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.linear_indexer import (
_build_connector_doc,
index_linear_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
):
return {"id": issue_id, "identifier": identifier, "title": title}
def _make_formatted_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
state: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"id": issue_id,
"identifier": identifier,
"title": title,
"state": state,
"priority": priority,
"description": "Some description",
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a Linear issue produces a ConnectorDocument with correct fields."""
issue = _make_issue(issue_id="abc-123", identifier="ENG-42", title="Fix login bug")
formatted = _make_formatted_issue(
issue_id="abc-123",
identifier="ENG-42",
title="Fix login bug",
state="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix login bug\n\nDescription here"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: Fix login bug"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.LINEAR_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "abc-123"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "Fix login bug"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Linear Issue"
assert doc.metadata["connector_type"] == "Linear"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-6
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_linear_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: (
f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
)
)
return client
@pytest.fixture
def linear_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_linear_issues."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
linear_client = _mock_linear_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod,
"LinearConnector",
MagicMock(return_value=linear_client),
)
monkeypatch.setattr(
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"linear_client": linear_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_linear_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(linear_mocks):
"""One valid issue is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_called_once()
call_args = linear_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.LINEAR_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(linear_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(linear_mocks)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(linear_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(linear_mocks)
linear_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing ID / title)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_id_are_skipped(linear_mocks):
"""Issues without id are skipped and not passed to the pipeline."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "", "identifier": "ENG-2", "title": "No ID"},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(linear_mocks):
"""Issues without title are skipped."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "id-2", "identifier": "ENG-2", "title": ""},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
"""Issues whose content hash matches an existing document are skipped."""
issues = [
_make_issue(issue_id="new-1", identifier="ENG-1", title="New"),
_make_issue(issue_id="dup-1", identifier="ENG-2", title="Dup"),
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(linear_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(linear_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues early return
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(linear_mocks):
"""When no issues are found, returns (0, 0, None) and pipeline is not called."""
linear_mocks["linear_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_not_called()
async def test_failed_docs_warning_in_result(linear_mocks):
"""When documents fail indexing, the warning includes the count."""
linear_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(linear_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,365 @@
"""Tests for Notion indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.notion_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.notion_indexer import (
_build_connector_doc,
index_notion_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(page_id: str = "page-1", title: str = "Test Page", content=None):
if content is None:
content = [{"type": "paragraph", "content": "Hello world", "children": []}]
return {"page_id": page_id, "title": title, "content": content}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a single Notion page produces a ConnectorDocument with correct fields."""
page = _make_page(page_id="abc-123", title="My Notion Page")
markdown = "# My Notion Page\n\nHello world"
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "My Notion Page"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.NOTION_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_title"] == "My Notion Page"
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Notion Page"
assert doc.metadata["connector_type"] == "Notion"
assert doc.fallback_summary is not None
assert "My Notion Page" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_page(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7 (full index_notion_pages tests)
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_notion_client(pages=None, skipped_count=0, legacy_token=False):
client = MagicMock()
client.get_all_pages = AsyncMock(return_value=pages if pages is not None else [])
client.get_skipped_content_count = MagicMock(return_value=skipped_count)
client.is_using_legacy_token = MagicMock(return_value=legacy_token)
client.close = AsyncMock()
client.set_retry_callback = MagicMock()
return client
@pytest.fixture
def notion_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_notion_pages."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
notion_client = _mock_notion_client(pages=[_make_page()])
monkeypatch.setattr(
_mod,
"NotionHistoryConnector",
MagicMock(return_value=notion_client),
)
monkeypatch.setattr(
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
monkeypatch.setattr(
_mod,
"process_blocks",
MagicMock(return_value="Converted markdown content"),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"notion_client": notion_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_notion_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_retry_callback=overrides.get("on_retry_callback"),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(notion_mocks):
"""One valid page is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_called_once()
call_args = notion_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.NOTION_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(notion_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(notion_mocks)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(notion_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(notion_mocks)
notion_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (no content / missing ID)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
"""Pages without page_id are skipped and not passed to the pipeline."""
pages = [
_make_page(page_id="valid-1"),
{
"title": "No ID page",
"content": [{"type": "paragraph", "content": "text", "children": []}],
},
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_pages_with_no_content_are_skipped(notion_mocks, monkeypatch):
"""Pages with empty content are skipped."""
pages = [
_make_page(page_id="valid-1"),
_make_page(page_id="empty-1", content=[]),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(notion_mocks, monkeypatch):
"""Pages whose content hash matches an existing document are skipped."""
pages = [
_make_page(page_id="new-1"),
_make_page(page_id="dup-1"),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(notion_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(notion_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Notion-specific warning messages
# ---------------------------------------------------------------------------
async def test_skipped_ai_content_warning_in_result(notion_mocks):
"""When Notion AI content was skipped, the warning message includes it."""
notion_mocks["notion_client"].get_skipped_content_count.return_value = 3
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "API limitation" in warning
async def test_legacy_token_warning_in_result(notion_mocks):
"""When using legacy token, the warning message includes a notice."""
notion_mocks["notion_client"].is_using_legacy_token.return_value = True
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "legacy token" in warning.lower()
async def test_failed_docs_warning_in_result(notion_mocks):
"""When documents fail indexing, the warning includes the count."""
notion_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "2 failed" in warning
# ---------------------------------------------------------------------------
# Slice 7: Empty pages early return
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(notion_mocks):
"""When no pages are found, returns (0, 0, None) and updates last_indexed."""
notion_mocks["notion_client"].get_all_pages.return_value = []
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_not_called()

View file

@ -0,0 +1,131 @@
"""Unit tests for IndexingPipelineService.create_placeholder_documents."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy.exc import IntegrityError
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_placeholder(**overrides) -> PlaceholderInfo:
defaults = {
"title": "Test Doc",
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
"unique_id": "file-001",
"search_space_id": 1,
"connector_id": 42,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return PlaceholderInfo(**defaults)
def _uid_hash(p: PlaceholderInfo) -> str:
return compute_identifier_hash(
p.document_type.value, p.unique_id, p.search_space_id
)
def _session_with_existing_hashes(existing: set[str] | None = None):
"""Build an AsyncMock session whose batch-query returns *existing* hashes."""
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = list(existing or [])
session.execute = AsyncMock(return_value=result)
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_empty_input_returns_zero_without_db_calls():
session = AsyncMock()
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([])
assert result == 0
session.execute.assert_not_awaited()
session.commit.assert_not_awaited()
async def test_creates_documents_with_pending_status_and_commits():
session = _session_with_existing_hashes(set())
pipeline = IndexingPipelineService(session)
p = _make_placeholder(title="My File", unique_id="file-abc")
result = await pipeline.create_placeholder_documents([p])
assert result == 1
session.add.assert_called_once()
doc = session.add.call_args[0][0]
assert doc.title == "My File"
assert doc.document_type == DocumentType.GOOGLE_DRIVE_FILE
assert doc.content == "Pending..."
assert DocumentStatus.is_state(doc.status, DocumentStatus.PENDING)
assert doc.search_space_id == 1
assert doc.connector_id == 42
session.commit.assert_awaited_once()
async def test_existing_documents_are_skipped():
"""Placeholders whose unique_identifier_hash already exists are not re-created."""
existing_p = _make_placeholder(unique_id="already-there")
new_p = _make_placeholder(unique_id="brand-new")
existing_hash = _uid_hash(existing_p)
session = _session_with_existing_hashes({existing_hash})
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([existing_p, new_p])
assert result == 1
doc = session.add.call_args[0][0]
assert doc.unique_identifier_hash == _uid_hash(new_p)
async def test_duplicate_unique_ids_within_input_are_deduped():
"""Same unique_id passed twice only produces one placeholder."""
p1 = _make_placeholder(unique_id="dup-id", title="First")
p2 = _make_placeholder(unique_id="dup-id", title="Second")
session = _session_with_existing_hashes(set())
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([p1, p2])
assert result == 1
session.add.assert_called_once()
async def test_integrity_error_on_commit_returns_zero():
"""IntegrityError during commit (race condition) is swallowed gracefully."""
session = _session_with_existing_hashes(set())
session.commit = AsyncMock(side_effect=IntegrityError("dup", {}, None))
pipeline = IndexingPipelineService(session)
p = _make_placeholder()
result = await pipeline.create_placeholder_documents([p])
assert result == 0
session.rollback.assert_awaited_once()

View file

@ -3,6 +3,7 @@ import pytest
from app.db import DocumentType
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash,
)
@ -61,3 +62,23 @@ def test_different_content_produces_different_content_hash(make_connector_docume
doc_a = make_connector_document(source_markdown="Original content")
doc_b = make_connector_document(source_markdown="Updated content")
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
def test_compute_identifier_hash_matches_connector_doc_hash(make_connector_document):
"""Raw-args hash equals ConnectorDocument hash for equivalent inputs."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-123",
search_space_id=5,
)
raw_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-123", 5)
assert raw_hash == compute_unique_identifier_hash(doc)
def test_compute_identifier_hash_differs_for_different_inputs():
"""Different arguments produce different hashes."""
h1 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 1)
h2 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-2", 1)
h3 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 2)
h4 = compute_identifier_hash("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "file-1", 1)
assert len({h1, h2, h3, h4}) == 4

View file

@ -0,0 +1,80 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
return AsyncMock()
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
async def test_calls_prepare_then_index_per_document(pipeline, make_connector_document):
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
doc2 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-2",
search_space_id=1,
)
orm1 = MagicMock(spec=Document)
orm1.unique_identifier_hash = compute_unique_identifier_hash(doc1)
orm2 = MagicMock(spec=Document)
orm2.unique_identifier_hash = compute_unique_identifier_hash(doc2)
mock_llm = MagicMock()
pipeline.prepare_for_indexing = AsyncMock(return_value=[orm1, orm2])
pipeline.index = AsyncMock(side_effect=lambda doc, cdoc, llm: doc)
results = await pipeline.index_batch([doc1, doc2], mock_llm)
pipeline.prepare_for_indexing.assert_awaited_once_with([doc1, doc2])
assert pipeline.index.await_count == 2
assert results == [orm1, orm2]
async def test_empty_input_returns_empty(pipeline):
"""Empty connector_docs list returns empty result."""
pipeline.prepare_for_indexing = AsyncMock(return_value=[])
results = await pipeline.index_batch([], MagicMock())
assert results == []
async def test_skips_document_without_matching_connector_doc(
pipeline, make_connector_document
):
"""If prepare returns a doc whose hash has no matching ConnectorDocument, it's skipped."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
orphan_orm = MagicMock(spec=Document)
orphan_orm.unique_identifier_hash = "nonexistent-hash"
pipeline.prepare_for_indexing = AsyncMock(return_value=[orphan_orm])
pipeline.index = AsyncMock()
results = await pipeline.index_batch([doc1], MagicMock())
pipeline.index.assert_not_awaited()
assert results == []

View file

@ -0,0 +1,188 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_orm_doc(connector_doc, doc_id):
"""Create a MagicMock Document bound to a ConnectorDocument's hash."""
doc = MagicMock(spec=Document)
doc.id = doc_id
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
doc.status = DocumentStatus.pending()
return doc
async def test_index_calls_embed_and_chunk_via_to_thread(
pipeline, make_connector_document, monkeypatch
):
"""index() runs embed_texts and chunk_text via asyncio.to_thread, not blocking the loop."""
to_thread_calls = []
original_to_thread = asyncio.to_thread
async def tracking_to_thread(func, *args, **kwargs):
to_thread_calls.append(func.__name__)
return await original_to_thread(func, *args, **kwargs)
monkeypatch.setattr(asyncio, "to_thread", tracking_to_thread)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
AsyncMock(return_value="Summary."),
)
mock_chunk = MagicMock(return_value=["chunk1"])
mock_chunk.__name__ = "chunk_text"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk,
)
mock_embed = MagicMock(
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
)
mock_embed.__name__ = "embed_texts"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
mock_embed,
)
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
document = MagicMock(spec=Document)
document.id = 1
document.status = DocumentStatus.pending()
await pipeline.index(document, connector_doc, llm=MagicMock())
assert "chunk_text" in to_thread_calls
assert "embed_texts" in to_thread_calls
def _mock_session_factory(orm_docs_by_id):
"""Replace get_celery_session_maker with a two-level callable.
get_celery_session_maker() -> session_maker
session_maker() -> async context manager yielding a mock session
"""
def _get_maker():
def _make_session():
session = MagicMock()
session.get = AsyncMock(
side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
)
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx
return _make_session
return _get_maker
async def test_batch_parallel_indexes_all_documents(
pipeline, make_connector_document, monkeypatch
):
"""index_batch_parallel indexes all documents and returns correct counts."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
index_calls = []
async def fake_index(self, document, connector_doc, llm):
index_calls.append(document.id)
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", fake_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=2
)
assert indexed == 3
assert failed == 0
assert sorted(index_calls) == [1, 2, 3]
async def test_batch_parallel_one_failure_does_not_affect_others(
pipeline, make_connector_document, monkeypatch
):
"""One document failure doesn't prevent other documents from being indexed."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
async def failing_index(self, document, connector_doc, llm):
if document.id == 2:
raise RuntimeError("LLM exploded")
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", failing_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=4
)
assert indexed == 2
assert failed == 1

View file

@ -0,0 +1,127 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_execute_side_effect(doc_by_hash: dict):
"""Return a side_effect for session.execute that resolves documents by hash."""
async def _side_effect(stmt):
result = MagicMock()
for h, doc in doc_by_hash.items():
if h in str(stmt.compile(compile_kwargs={"literal_binds": True})):
result.scalars.return_value.first.return_value = doc
return result
result.scalars.return_value.first.return_value = None
return result
return _side_effect
async def test_updates_hash_and_type_for_legacy_document(
pipeline, mock_session, make_connector_document
):
"""Legacy Composio document gets unique_identifier_hash and document_type updated."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-abc",
search_space_id=1,
)
legacy_hash = compute_identifier_hash("COMPOSIO_GMAIL_CONNECTOR", "msg-abc", 1)
native_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-abc", 1)
existing = MagicMock(spec=Document)
existing.unique_identifier_hash = legacy_hash
existing.document_type = DocumentType.COMPOSIO_GMAIL_CONNECTOR
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
assert existing.unique_identifier_hash == native_hash
assert existing.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
mock_session.commit.assert_awaited_once()
async def test_noop_when_no_legacy_document_exists(
pipeline, mock_session, make_connector_document
):
"""No updates when no legacy Composio document is found in DB."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-xyz",
search_space_id=1,
)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
mock_session.commit.assert_awaited_once()
async def test_skips_non_google_doc_types(
pipeline, mock_session, make_connector_document
):
"""Non-Google doc types have no legacy mapping and trigger no DB query."""
doc = make_connector_document(
document_type=DocumentType.SLACK_CONNECTOR,
unique_id="slack-123",
search_space_id=1,
)
await pipeline.migrate_legacy_docs([doc])
mock_session.execute.assert_not_awaited()
mock_session.commit.assert_awaited_once()
async def test_handles_all_three_google_types(
pipeline, mock_session, make_connector_document
):
"""Each native Google type correctly maps to its Composio legacy type."""
mappings = [
(DocumentType.GOOGLE_GMAIL_CONNECTOR, "COMPOSIO_GMAIL_CONNECTOR"),
(DocumentType.GOOGLE_CALENDAR_CONNECTOR, "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"),
(DocumentType.GOOGLE_DRIVE_FILE, "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"),
]
for native_type, expected_legacy in mappings:
doc = make_connector_document(
document_type=native_type,
unique_id="id-1",
search_space_id=1,
)
existing = MagicMock(spec=Document)
existing.document_type = DocumentType(expected_legacy)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
mock_session.commit = AsyncMock()
await pipeline.migrate_legacy_docs([doc])
assert existing.document_type == native_type

View file

@ -0,0 +1,110 @@
"""Unit tests for the duplicate-content safety logic in prepare_for_indexing.
Verifies that when an existing document's updated content matches another
document's content_hash, the system marks it as failed (for placeholders)
or leaves it untouched (for ready documents) never deletes.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_connector_doc(**overrides) -> ConnectorDocument:
defaults = {
"title": "Test Doc",
"source_markdown": "## Some new content",
"unique_id": "file-001",
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
"search_space_id": 1,
"connector_id": 42,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
def _make_existing_doc(connector_doc: ConnectorDocument, *, status: dict) -> MagicMock:
"""Build a MagicMock that looks like an ORM Document with given status."""
doc = MagicMock(spec=Document)
doc.id = 999
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
doc.content_hash = "old-placeholder-content-hash"
doc.title = connector_doc.title
doc.status = status
return doc
def _mock_session_for_dedup(existing_doc, *, has_duplicate: bool):
"""Build a session whose sequential execute() calls return:
1. The *existing_doc* for the unique_identifier_hash lookup.
2. A row (or None) for the duplicate content_hash check.
"""
session = AsyncMock()
existing_result = MagicMock()
existing_result.scalars.return_value.first.return_value = existing_doc
dup_result = MagicMock()
dup_result.scalars.return_value.first.return_value = 42 if has_duplicate else None
session.execute = AsyncMock(side_effect=[existing_result, dup_result])
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_pending_placeholder_with_duplicate_content_is_marked_failed():
"""A placeholder (pending) whose updated content duplicates another doc
must be marked as FAILED never deleted."""
cdoc = _make_connector_doc(source_markdown="## Shared content")
existing = _make_existing_doc(cdoc, status=DocumentStatus.pending())
session = _mock_session_for_dedup(existing, has_duplicate=True)
pipeline = IndexingPipelineService(session)
results = await pipeline.prepare_for_indexing([cdoc])
assert results == [], "duplicate should not be returned for indexing"
assert DocumentStatus.is_state(existing.status, DocumentStatus.FAILED)
assert "Duplicate content" in existing.status.get("reason", "")
session.delete.assert_not_called()
async def test_ready_document_with_duplicate_content_is_left_untouched():
"""A READY document whose updated content duplicates another doc
must be left completely untouched not failed, not deleted."""
cdoc = _make_connector_doc(source_markdown="## Shared content")
existing = _make_existing_doc(cdoc, status=DocumentStatus.ready())
session = _mock_session_for_dedup(existing, has_duplicate=True)
pipeline = IndexingPipelineService(session)
results = await pipeline.prepare_for_indexing([cdoc])
assert results == [], "duplicate should not be returned for indexing"
assert DocumentStatus.is_state(existing.status, DocumentStatus.READY)
session.delete.assert_not_called()

View file

@ -0,0 +1,133 @@
"""Unit tests for knowledge_search middleware helpers.
These test pure functions that don't require a database.
"""
import pytest
from app.agents.new_chat.middleware.knowledge_search import (
_build_document_xml,
_resolve_search_types,
)
pytestmark = pytest.mark.unit
# ── _resolve_search_types ──────────────────────────────────────────────
class TestResolveSearchTypes:
def test_returns_none_when_no_inputs(self):
assert _resolve_search_types(None, None) is None
def test_returns_none_when_both_empty(self):
assert _resolve_search_types([], []) is None
def test_includes_legacy_type_for_google_gmail(self):
result = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
assert "GOOGLE_GMAIL_CONNECTOR" in result
assert "COMPOSIO_GMAIL_CONNECTOR" in result
def test_includes_legacy_type_for_google_drive(self):
result = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
assert "GOOGLE_DRIVE_FILE" in result
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in result
def test_includes_legacy_type_for_google_calendar(self):
result = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
assert "GOOGLE_CALENDAR_CONNECTOR" in result
assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in result
def test_no_legacy_expansion_for_unrelated_types(self):
result = _resolve_search_types(["FILE", "NOTE"], None)
assert set(result) == {"FILE", "NOTE"}
def test_combines_connectors_and_document_types(self):
result = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(result))
def test_deduplicates(self):
result = _resolve_search_types(["FILE", "FILE"], ["FILE"])
assert result.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
@pytest.fixture
def sample_document(self):
return {
"document_id": 42,
"document": {
"id": 42,
"document_type": "FILE",
"title": "Test Doc",
"metadata": {"url": "https://example.com"},
},
"chunks": [
{"chunk_id": 101, "content": "First chunk content"},
{"chunk_id": 102, "content": "Second chunk content"},
{"chunk_id": 103, "content": "Third chunk content"},
],
}
def test_contains_document_metadata(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_id>42</document_id>" in xml
assert "<document_type>FILE</document_type>" in xml
assert "Test Doc" in xml
def test_contains_chunk_index(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<chunk_index>" in xml
assert "</chunk_index>" in xml
assert 'chunk_id="101"' in xml
assert 'chunk_id="102"' in xml
assert 'chunk_id="103"' in xml
def test_matched_chunks_flagged_in_index(self, sample_document):
xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
lines = xml.split("\n")
for line in lines:
if 'chunk_id="101"' in line:
assert 'matched="true"' in line
if 'chunk_id="102"' in line:
assert 'matched="true"' not in line
if 'chunk_id="103"' in line:
assert 'matched="true"' in line
def test_chunk_content_in_document_content_section(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_content>" in xml
assert "First chunk content" in xml
assert "Second chunk content" in xml
assert "Third chunk content" in xml
def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
"""Verify that the line ranges in chunk_index actually point to the right content."""
xml = _build_document_xml(sample_document, matched_chunk_ids={101})
xml_lines = xml.split("\n")
for line in xml_lines:
if 'chunk_id="101"' in line and "lines=" in line:
import re
m = re.search(r'lines="(\d+)-(\d+)"', line)
assert m, f"No lines= attribute found in: {line}"
start, _end = int(m.group(1)), int(m.group(2))
target_line = xml_lines[start - 1]
assert "101" in target_line
assert "First chunk content" in target_line
break
else:
pytest.fail("chunk_id=101 entry not found in chunk_index")
def test_splits_into_lines_correctly(self, sample_document):
"""Each chunk occupies exactly one line (no embedded newlines)."""
xml = _build_document_xml(sample_document)
lines = xml.split("\n")
chunk_lines = [
line for line in lines if "<![CDATA[" in line and "<chunk" in line
]
assert len(chunk_lines) == 3

8877
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,6 @@
export const IPC_CHANNELS = {
OPEN_EXTERNAL: 'open-external',
GET_APP_VERSION: 'get-app-version',
DEEP_LINK: 'deep-link',
QUICK_ASK_TEXT: 'quick-ask-text',
} as const;

View file

@ -0,0 +1,19 @@
import { app, ipcMain, shell } from 'electron';
import { IPC_CHANNELS } from './channels';
export function registerIpcHandlers(): void {
ipcMain.on(IPC_CHANNELS.OPEN_EXTERNAL, (_event, url: string) => {
try {
const parsed = new URL(url);
if (parsed.protocol === 'http:' || parsed.protocol === 'https:') {
shell.openExternal(url);
}
} catch {
// invalid URL — ignore
}
});
ipcMain.handle(IPC_CHANNELS.GET_APP_VERSION, () => {
return app.getVersion();
});
}

View file

@ -1,258 +1,20 @@
import { app, BrowserWindow, shell, ipcMain, session, dialog, clipboard, Menu } from 'electron';
import path from 'path';
import { getPort } from 'get-port-please';
import { autoUpdater } from 'electron-updater';
import { app, BrowserWindow } from 'electron';
import { registerGlobalErrorHandlers, showErrorDialog } from './modules/errors';
import { startNextServer } from './modules/server';
import { createMainWindow } from './modules/window';
import { setupDeepLinks, handlePendingDeepLink } from './modules/deep-links';
import { setupAutoUpdater } from './modules/auto-updater';
import { setupMenu } from './modules/menu';
import { registerQuickAsk, unregisterQuickAsk } from './modules/quick-ask';
import { registerIpcHandlers } from './ipc/handlers';
function showErrorDialog(title: string, error: unknown): void {
const err = error instanceof Error ? error : new Error(String(error));
console.error(`${title}:`, err);
registerGlobalErrorHandlers();
if (app.isReady()) {
const detail = err.stack || err.message;
const buttonIndex = dialog.showMessageBoxSync({
type: 'error',
buttons: ['OK', process.platform === 'darwin' ? 'Copy Error' : 'Copy error'],
defaultId: 0,
noLink: true,
message: title,
detail,
});
if (buttonIndex === 1) {
clipboard.writeText(`${title}\n${detail}`);
}
} else {
dialog.showErrorBox(title, err.stack || err.message);
}
}
process.on('uncaughtException', (error) => {
showErrorDialog('Unhandled Error', error);
});
process.on('unhandledRejection', (reason) => {
showErrorDialog('Unhandled Promise Rejection', reason);
});
const isDev = !app.isPackaged;
let mainWindow: BrowserWindow | null = null;
let deepLinkUrl: string | null = null;
let serverPort: number = 3000; // overwritten at startup with a free port
const PROTOCOL = 'surfsense';
// Injected at compile time from .env via esbuild define
const HOSTED_FRONTEND_URL = process.env.HOSTED_FRONTEND_URL as string;
function getStandalonePath(): string {
if (isDev) {
return path.join(__dirname, '..', '..', 'surfsense_web', '.next', 'standalone', 'surfsense_web');
}
return path.join(process.resourcesPath, 'standalone');
}
async function waitForServer(url: string, maxRetries = 60): Promise<boolean> {
for (let i = 0; i < maxRetries; i++) {
try {
const res = await fetch(url);
if (res.ok || res.status === 404 || res.status === 500) return true;
} catch {
// not ready yet
}
await new Promise((r) => setTimeout(r, 500));
}
return false;
}
async function startNextServer(): Promise<void> {
if (isDev) return;
serverPort = await getPort({ port: 3000, portRange: [30_011, 50_000] });
console.log(`Selected port ${serverPort}`);
const standalonePath = getStandalonePath();
const serverScript = path.join(standalonePath, 'server.js');
// The standalone server.js reads PORT / HOSTNAME from process.env and
// uses process.chdir(__dirname). Running it via require() in the same
// process is the proven approach (avoids spawning a second Electron
// instance whose ASAR-patched fs breaks Next.js static file serving).
process.env.PORT = String(serverPort);
process.env.HOSTNAME = 'localhost';
process.env.NODE_ENV = 'production';
process.chdir(standalonePath);
require(serverScript);
const ready = await waitForServer(`http://localhost:${serverPort}`);
if (!ready) {
throw new Error('Next.js server failed to start within 30 s');
}
console.log(`Next.js server ready on port ${serverPort}`);
}
function createWindow() {
mainWindow = new BrowserWindow({
width: 1280,
height: 800,
minWidth: 800,
minHeight: 600,
webPreferences: {
preload: path.join(__dirname, 'preload.js'),
contextIsolation: true,
nodeIntegration: false,
sandbox: true,
webviewTag: false,
},
show: false,
titleBarStyle: 'hiddenInset',
});
mainWindow.once('ready-to-show', () => {
mainWindow?.show();
});
mainWindow.loadURL(`http://localhost:${serverPort}/login`);
// External links open in system browser, not in the Electron window
mainWindow.webContents.setWindowOpenHandler(({ url }) => {
if (url.startsWith('http://localhost')) {
return { action: 'allow' };
}
shell.openExternal(url);
return { action: 'deny' };
});
// Intercept backend OAuth redirects targeting the hosted web frontend
// and rewrite them to localhost so the user stays in the desktop app.
const filter = { urls: [`${HOSTED_FRONTEND_URL}/*`] };
session.defaultSession.webRequest.onBeforeRequest(filter, (details, callback) => {
const rewritten = details.url.replace(HOSTED_FRONTEND_URL, `http://localhost:${serverPort}`);
callback({ redirectURL: rewritten });
});
mainWindow.webContents.on('did-fail-load', (_event, errorCode, errorDescription, validatedURL) => {
console.error(`Failed to load ${validatedURL}: ${errorDescription} (${errorCode})`);
if (errorCode === -3) return; // ERR_ABORTED — normal during redirects
showErrorDialog('Page failed to load', new Error(`${errorDescription} (${errorCode})\n${validatedURL}`));
});
if (isDev) {
mainWindow.webContents.openDevTools();
}
mainWindow.on('closed', () => {
mainWindow = null;
});
}
// IPC handlers
ipcMain.on('open-external', (_event, url: string) => {
try {
const parsed = new URL(url);
if (parsed.protocol === 'http:' || parsed.protocol === 'https:') {
shell.openExternal(url);
}
} catch {
// invalid URL — ignore
}
});
ipcMain.handle('get-app-version', () => {
return app.getVersion();
});
// Deep link handling
function handleDeepLink(url: string) {
if (!url.startsWith(`${PROTOCOL}://`)) return;
deepLinkUrl = url;
if (!mainWindow) return;
// Rewrite surfsense:// deep link to localhost so TokenHandler.tsx processes it
const parsed = new URL(url);
if (parsed.hostname === 'auth' && parsed.pathname === '/callback') {
const params = parsed.searchParams.toString();
mainWindow.loadURL(`http://localhost:${serverPort}/auth/callback?${params}`);
}
mainWindow.show();
mainWindow.focus();
}
// Single instance lock — second instance passes deep link to first
const gotTheLock = app.requestSingleInstanceLock();
if (!gotTheLock) {
if (!setupDeepLinks()) {
app.quit();
} else {
app.on('second-instance', (_event, argv) => {
// Windows/Linux: deep link URL is in argv
const url = argv.find((arg) => arg.startsWith(`${PROTOCOL}://`));
if (url) handleDeepLink(url);
if (mainWindow) {
if (mainWindow.isMinimized()) mainWindow.restore();
mainWindow.focus();
}
});
}
// macOS: deep link arrives via open-url event
app.on('open-url', (event, url) => {
event.preventDefault();
handleDeepLink(url);
});
// Register surfsense:// protocol
if (process.defaultApp) {
if (process.argv.length >= 2) {
app.setAsDefaultProtocolClient(PROTOCOL, process.execPath, [path.resolve(process.argv[1])]);
}
} else {
app.setAsDefaultProtocolClient(PROTOCOL);
}
function setupAutoUpdater() {
if (isDev) return;
autoUpdater.autoDownload = true;
autoUpdater.on('update-available', (info) => {
console.log(`Update available: ${info.version}`);
});
autoUpdater.on('update-downloaded', (info) => {
console.log(`Update downloaded: ${info.version}`);
dialog.showMessageBox({
type: 'info',
buttons: ['Restart', 'Later'],
defaultId: 0,
title: 'Update Ready',
message: `Version ${info.version} has been downloaded. Restart to apply the update.`,
}).then(({ response }) => {
if (response === 0) {
autoUpdater.quitAndInstall();
}
});
});
autoUpdater.on('error', (err) => {
console.error('Auto-updater error:', err);
});
autoUpdater.checkForUpdates();
}
function setupMenu() {
const isMac = process.platform === 'darwin';
const template: Electron.MenuItemConstructorOptions[] = [
...(isMac ? [{ role: 'appMenu' as const }] : []),
{ role: 'fileMenu' as const },
{ role: 'editMenu' as const },
{ role: 'viewMenu' as const },
{ role: 'windowMenu' as const },
];
Menu.setApplicationMenu(Menu.buildFromTemplate(template));
}
registerIpcHandlers();
// App lifecycle
app.whenReady().then(async () => {
@ -264,18 +26,15 @@ app.whenReady().then(async () => {
setTimeout(() => app.quit(), 0);
return;
}
createWindow();
createMainWindow();
registerQuickAsk();
setupAutoUpdater();
// If a deep link was received before the window was ready, handle it now
if (deepLinkUrl) {
handleDeepLink(deepLinkUrl);
deepLinkUrl = null;
}
handlePendingDeepLink();
app.on('activate', () => {
if (BrowserWindow.getAllWindows().length === 0) {
createWindow();
createMainWindow();
}
});
});
@ -287,5 +46,5 @@ app.on('window-all-closed', () => {
});
app.on('will-quit', () => {
// Server runs in-process — no child process to kill
unregisterQuickAsk();
});

View file

@ -0,0 +1,33 @@
import { app, dialog } from 'electron';
import { autoUpdater } from 'electron-updater';
export function setupAutoUpdater(): void {
if (!app.isPackaged) return;
autoUpdater.autoDownload = true;
autoUpdater.on('update-available', (info) => {
console.log(`Update available: ${info.version}`);
});
autoUpdater.on('update-downloaded', (info) => {
console.log(`Update downloaded: ${info.version}`);
dialog.showMessageBox({
type: 'info',
buttons: ['Restart', 'Later'],
defaultId: 0,
title: 'Update Ready',
message: `Version ${info.version} has been downloaded. Restart to apply the update.`,
}).then(({ response }) => {
if (response === 0) {
autoUpdater.quitAndInstall();
}
});
});
autoUpdater.on('error', (err) => {
console.log('Auto-updater: update check skipped —', err.message?.split('\n')[0]);
});
autoUpdater.checkForUpdates().catch(() => {});
}

View file

@ -0,0 +1,66 @@
import { app } from 'electron';
import path from 'path';
import { getMainWindow } from './window';
import { getServerPort } from './server';
const PROTOCOL = 'surfsense';
let deepLinkUrl: string | null = null;
function handleDeepLink(url: string) {
if (!url.startsWith(`${PROTOCOL}://`)) return;
deepLinkUrl = url;
const win = getMainWindow();
if (!win) return;
const parsed = new URL(url);
if (parsed.hostname === 'auth' && parsed.pathname === '/callback') {
const params = parsed.searchParams.toString();
win.loadURL(`http://localhost:${getServerPort()}/auth/callback?${params}`);
}
win.show();
win.focus();
}
export function setupDeepLinks(): boolean {
const gotTheLock = app.requestSingleInstanceLock();
if (!gotTheLock) {
return false;
}
app.on('second-instance', (_event, argv) => {
const url = argv.find((arg) => arg.startsWith(`${PROTOCOL}://`));
if (url) handleDeepLink(url);
const win = getMainWindow();
if (win) {
if (win.isMinimized()) win.restore();
win.focus();
}
});
app.on('open-url', (event, url) => {
event.preventDefault();
handleDeepLink(url);
});
if (process.defaultApp) {
if (process.argv.length >= 2) {
app.setAsDefaultProtocolClient(PROTOCOL, process.execPath, [path.resolve(process.argv[1])]);
}
} else {
app.setAsDefaultProtocolClient(PROTOCOL);
}
return true;
}
export function handlePendingDeepLink(): void {
if (deepLinkUrl) {
handleDeepLink(deepLinkUrl);
deepLinkUrl = null;
}
}

View file

@ -0,0 +1,33 @@
import { app, clipboard, dialog } from 'electron';
export function showErrorDialog(title: string, error: unknown): void {
const err = error instanceof Error ? error : new Error(String(error));
console.error(`${title}:`, err);
if (app.isReady()) {
const detail = err.stack || err.message;
const buttonIndex = dialog.showMessageBoxSync({
type: 'error',
buttons: ['OK', process.platform === 'darwin' ? 'Copy Error' : 'Copy error'],
defaultId: 0,
noLink: true,
message: title,
detail,
});
if (buttonIndex === 1) {
clipboard.writeText(`${title}\n${detail}`);
}
} else {
dialog.showErrorBox(title, err.stack || err.message);
}
}
export function registerGlobalErrorHandlers(): void {
process.on('uncaughtException', (error) => {
showErrorDialog('Unhandled Error', error);
});
process.on('unhandledRejection', (reason) => {
showErrorDialog('Unhandled Promise Rejection', reason);
});
}

View file

@ -0,0 +1,13 @@
import { Menu } from 'electron';
export function setupMenu(): void {
const isMac = process.platform === 'darwin';
const template: Electron.MenuItemConstructorOptions[] = [
...(isMac ? [{ role: 'appMenu' as const }] : []),
{ role: 'fileMenu' as const },
{ role: 'editMenu' as const },
{ role: 'viewMenu' as const },
{ role: 'windowMenu' as const },
];
Menu.setApplicationMenu(Menu.buildFromTemplate(template));
}

View file

@ -0,0 +1,108 @@
import { BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron';
import path from 'path';
import { IPC_CHANNELS } from '../ipc/channels';
import { getServerPort } from './server';
const SHORTCUT = 'CommandOrControl+Option+S';
let quickAskWindow: BrowserWindow | null = null;
let pendingText = '';
function hideQuickAsk(): void {
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
quickAskWindow.hide();
}
}
function clampToScreen(x: number, y: number, w: number, h: number): { x: number; y: number } {
const display = screen.getDisplayNearestPoint({ x, y });
const { x: dx, y: dy, width: dw, height: dh } = display.workArea;
return {
x: Math.max(dx, Math.min(x, dx + dw - w)),
y: Math.max(dy, Math.min(y, dy + dh - h)),
};
}
function createQuickAskWindow(x: number, y: number): BrowserWindow {
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
quickAskWindow.setPosition(x, y);
quickAskWindow.show();
quickAskWindow.focus();
return quickAskWindow;
}
quickAskWindow = new BrowserWindow({
width: 450,
height: 550,
x,
y,
...(process.platform === 'darwin'
? { type: 'panel' as const }
: { type: 'toolbar' as const, alwaysOnTop: true }),
resizable: true,
fullscreenable: false,
maximizable: false,
webPreferences: {
preload: path.join(__dirname, 'preload.js'),
contextIsolation: true,
nodeIntegration: false,
sandbox: true,
},
show: false,
skipTaskbar: true,
});
quickAskWindow.loadURL(`http://localhost:${getServerPort()}/dashboard`);
quickAskWindow.once('ready-to-show', () => {
quickAskWindow?.show();
});
quickAskWindow.webContents.on('before-input-event', (_event, input) => {
if (input.key === 'Escape') hideQuickAsk();
});
quickAskWindow.webContents.setWindowOpenHandler(({ url }) => {
if (url.startsWith('http://localhost')) {
return { action: 'allow' };
}
shell.openExternal(url);
return { action: 'deny' };
});
quickAskWindow.on('closed', () => {
quickAskWindow = null;
});
return quickAskWindow;
}
export function registerQuickAsk(): void {
const ok = globalShortcut.register(SHORTCUT, () => {
if (quickAskWindow && !quickAskWindow.isDestroyed() && quickAskWindow.isVisible()) {
hideQuickAsk();
return;
}
const text = clipboard.readText().trim();
if (!text) return;
pendingText = text;
const cursor = screen.getCursorScreenPoint();
const pos = clampToScreen(cursor.x, cursor.y, 450, 550);
createQuickAskWindow(pos.x, pos.y);
});
if (!ok) {
console.log(`Quick-ask: failed to register ${SHORTCUT}`);
}
ipcMain.handle(IPC_CHANNELS.QUICK_ASK_TEXT, () => {
const text = pendingText;
pendingText = '';
return text;
});
}
export function unregisterQuickAsk(): void {
globalShortcut.unregister(SHORTCUT);
}

View file

@ -0,0 +1,53 @@
import path from 'path';
import { app } from 'electron';
import { getPort } from 'get-port-please';
const isDev = !app.isPackaged;
let serverPort = 3000;
export function getServerPort(): number {
return serverPort;
}
function getStandalonePath(): string {
if (isDev) {
return path.join(__dirname, '..', '..', 'surfsense_web', '.next', 'standalone', 'surfsense_web');
}
return path.join(process.resourcesPath, 'standalone');
}
async function waitForServer(url: string, maxRetries = 60): Promise<boolean> {
for (let i = 0; i < maxRetries; i++) {
try {
const res = await fetch(url);
if (res.ok || res.status === 404 || res.status === 500) return true;
} catch {
// not ready yet
}
await new Promise((r) => setTimeout(r, 500));
}
return false;
}
export async function startNextServer(): Promise<void> {
if (isDev) return;
serverPort = await getPort({ port: 3000, portRange: [30_011, 50_000] });
console.log(`Selected port ${serverPort}`);
const standalonePath = getStandalonePath();
const serverScript = path.join(standalonePath, 'server.js');
process.env.PORT = String(serverPort);
process.env.HOSTNAME = '0.0.0.0';
process.env.NODE_ENV = 'production';
process.chdir(standalonePath);
require(serverScript);
const ready = await waitForServer(`http://localhost:${serverPort}`);
if (!ready) {
throw new Error('Next.js server failed to start within 30 s');
}
console.log(`Next.js server ready on port ${serverPort}`);
}

View file

@ -0,0 +1,67 @@
import { app, BrowserWindow, shell, session } from 'electron';
import path from 'path';
import { showErrorDialog } from './errors';
import { getServerPort } from './server';
const isDev = !app.isPackaged;
const HOSTED_FRONTEND_URL = process.env.HOSTED_FRONTEND_URL as string;
let mainWindow: BrowserWindow | null = null;
export function getMainWindow(): BrowserWindow | null {
return mainWindow;
}
export function createMainWindow(): BrowserWindow {
mainWindow = new BrowserWindow({
width: 1280,
height: 800,
minWidth: 800,
minHeight: 600,
webPreferences: {
preload: path.join(__dirname, 'preload.js'),
contextIsolation: true,
nodeIntegration: false,
sandbox: true,
webviewTag: false,
},
show: false,
titleBarStyle: 'hiddenInset',
});
mainWindow.once('ready-to-show', () => {
mainWindow?.show();
});
mainWindow.loadURL(`http://localhost:${getServerPort()}/dashboard`);
mainWindow.webContents.setWindowOpenHandler(({ url }) => {
if (url.startsWith('http://localhost')) {
return { action: 'allow' };
}
shell.openExternal(url);
return { action: 'deny' };
});
const filter = { urls: [`${HOSTED_FRONTEND_URL}/*`] };
session.defaultSession.webRequest.onBeforeRequest(filter, (details, callback) => {
const rewritten = details.url.replace(HOSTED_FRONTEND_URL, `http://localhost:${getServerPort()}`);
callback({ redirectURL: rewritten });
});
mainWindow.webContents.on('did-fail-load', (_event, errorCode, errorDescription, validatedURL) => {
console.error(`Failed to load ${validatedURL}: ${errorDescription} (${errorCode})`);
if (errorCode === -3) return;
showErrorDialog('Page failed to load', new Error(`${errorDescription} (${errorCode})\n${validatedURL}`));
});
if (isDev) {
mainWindow.webContents.openDevTools();
}
mainWindow.on('closed', () => {
mainWindow = null;
});
return mainWindow;
}

View file

@ -1,4 +1,5 @@
const { contextBridge, ipcRenderer } = require('electron');
const { IPC_CHANNELS } = require('./ipc/channels');
contextBridge.exposeInMainWorld('electronAPI', {
versions: {
@ -7,13 +8,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
chrome: process.versions.chrome,
platform: process.platform,
},
openExternal: (url: string) => ipcRenderer.send('open-external', url),
getAppVersion: () => ipcRenderer.invoke('get-app-version'),
openExternal: (url: string) => ipcRenderer.send(IPC_CHANNELS.OPEN_EXTERNAL, url),
getAppVersion: () => ipcRenderer.invoke(IPC_CHANNELS.GET_APP_VERSION),
onDeepLink: (callback: (url: string) => void) => {
const listener = (_event: unknown, url: string) => callback(url);
ipcRenderer.on('deep-link', listener);
ipcRenderer.on(IPC_CHANNELS.DEEP_LINK, listener);
return () => {
ipcRenderer.removeListener('deep-link', listener);
ipcRenderer.removeListener(IPC_CHANNELS.DEEP_LINK, listener);
};
},
getQuickAskText: () => ipcRenderer.invoke(IPC_CHANNELS.QUICK_ASK_TEXT),
});

View file

@ -1,8 +1,14 @@
import { loader } from "fumadocs-core/source";
import type { Metadata } from "next";
import { changelog } from "@/.source/server";
import { formatDate } from "@/lib/utils";
import { getMDXComponents } from "@/mdx-components";
export const metadata: Metadata = {
title: "Changelog | SurfSense",
description: "See what's new in SurfSense.",
};
const source = loader({
baseUrl: "/changelog",
source: changelog.toFumadocsSource(),

View file

@ -1,6 +1,11 @@
import React from "react";
import type { Metadata } from "next";
import { ContactFormGridWithDetails } from "@/components/contact/contact-form";
export const metadata: Metadata = {
title: "Contact | SurfSense",
description: "Get in touch with the SurfSense team.",
};
const page = () => {
return (
<div>

View file

@ -5,7 +5,7 @@ import { AnimatePresence, motion } from "motion/react";
import Link from "next/link";
import { useRouter } from "next/navigation";
import { useTranslations } from "next-intl";
import { useEffect, useState } from "react";
import { useState } from "react";
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
import { Spinner } from "@/components/ui/spinner";
import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors";
@ -25,15 +25,10 @@ export function LocalLoginForm() {
title: null,
message: null,
});
const [authType, setAuthType] = useState<string | null>(null);
const authType = AUTH_TYPE;
const router = useRouter();
const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom);
useEffect(() => {
// Get the auth type from centralized config
setAuthType(AUTH_TYPE);
}, []);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError({ title: null, message: null }); // Clear any previous errors

View file

@ -1,6 +1,11 @@
import React from "react";
import type { Metadata } from "next";
import PricingBasic from "@/components/pricing/pricing-section";
export const metadata: Metadata = {
title: "Pricing | SurfSense",
description: "Explore SurfSense plans and pricing options.",
};
const page = () => {
return (
<div>

View file

@ -43,9 +43,12 @@ export default function RegisterPage() {
}
}, [router]);
const handleSubmit = async (e: React.FormEvent) => {
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
submitForm();
};
const submitForm = async () => {
// Form validation
if (password !== confirmPassword) {
setError({ title: t("password_mismatch"), message: t("passwords_no_match_desc") });
@ -140,7 +143,7 @@ export default function RegisterPage() {
if (shouldRetry(errorCode)) {
toastOptions.action = {
label: tCommon("retry"),
onClick: () => handleSubmit(e),
onClick: () => submitForm(),
};
}

View file

@ -63,7 +63,7 @@ export function DocumentTypeChip({ type, className }: { type: string; className?
checkTruncation();
window.addEventListener("resize", checkTruncation);
return () => window.removeEventListener("resize", checkTruncation);
}, []);
}, [type]);
const chip = (
<span

View file

@ -1,6 +1,6 @@
"use client";
import { ListFilter, Search, Upload, X } from "lucide-react";
import { FolderPlus, ListFilter, Search, Upload, X } from "lucide-react";
import { useTranslations } from "next-intl";
import React, { useCallback, useMemo, useRef, useState } from "react";
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
@ -8,6 +8,7 @@ import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Input } from "@/components/ui/input";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon";
@ -17,12 +18,14 @@ export function DocumentsFilters({
searchValue,
onToggleType,
activeTypes,
onCreateFolder,
}: {
typeCounts: Partial<Record<DocumentTypeEnum, number>>;
onSearch: (v: string) => void;
searchValue: string;
onToggleType: (type: DocumentTypeEnum, checked: boolean) => void;
activeTypes: DocumentTypeEnum[];
onCreateFolder?: () => void;
}) {
const t = useTranslations("documents");
const id = React.useId();
@ -194,6 +197,23 @@ export function DocumentsFilters({
)}
</div>
{/* New Folder Button */}
{onCreateFolder && (
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
className="h-9 w-9 shrink-0 border-dashed border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
onClick={onCreateFolder}
>
<FolderPlus size={14} />
</Button>
</TooltipTrigger>
<TooltipContent>New folder</TooltipContent>
</Tooltip>
)}
{/* Upload Button */}
<Button
data-joyride="upload-button"

View file

@ -473,14 +473,14 @@ export function DocumentsTableShell({
}, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]);
const bulkDeleteBar = hasDeletableSelection ? (
<div className="flex items-center justify-center py-1.5 border-b border-border/50 bg-destructive/5 shrink-0 animate-in fade-in slide-in-from-top-1 duration-150">
<div className="absolute inset-x-0 top-0 z-10 flex items-center justify-center py-1 pointer-events-none animate-in fade-in duration-150">
<button
type="button"
onClick={() => setBulkDeleteConfirmOpen(true)}
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors"
className="pointer-events-auto flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-lg text-xs font-medium hover:bg-destructive/90 transition-colors"
>
<Trash2 size={12} />
Delete ({deletableSelectedIds.length} selected)
Delete {deletableSelectedIds.length} {deletableSelectedIds.length === 1 ? "item" : "items"}
</button>
</div>
) : null;
@ -526,7 +526,6 @@ export function DocumentsTableShell({
</TableRow>
</TableHeader>
</Table>
{bulkDeleteBar}
{loading ? (
<div className="flex-1 overflow-auto">
<Table className="table-fixed w-full">
@ -594,7 +593,8 @@ export function DocumentsTableShell({
)}
</div>
) : (
<div ref={desktopScrollRef} className="flex-1 overflow-auto">
<div ref={desktopScrollRef} className="flex-1 overflow-auto relative">
{bulkDeleteBar}
<Table className="table-fixed w-full">
<TableBody>
{sorted.map((doc) => {
@ -788,9 +788,6 @@ export function DocumentsTableShell({
)}
</div>
{/* Mobile bulk delete bar */}
<div className="md:hidden">{bulkDeleteBar}</div>
{/* Mobile Card View */}
{loading ? (
<div className="md:hidden divide-y divide-border/50 flex-1 overflow-auto">
@ -846,8 +843,9 @@ export function DocumentsTableShell({
) : (
<div
ref={mobileScrollRef}
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto"
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto relative"
>
{bulkDeleteBar}
{sorted.map((doc) => {
const isMentioned = mentionedDocIds?.has(doc.id) ?? false;
const statusState = doc.status?.state ?? "ready";

View file

@ -24,8 +24,7 @@ import {
} from "@/components/ui/dropdown-menu";
import type { Document } from "./types";
// Only FILE and NOTE document types can be edited
const EDITABLE_DOCUMENT_TYPES = ["FILE", "NOTE"] as const;
const EDITABLE_DOCUMENT_TYPES = ["NOTE"] as const;
// SURFSENSE_DOCS are system-managed and cannot be deleted
const NON_DELETABLE_DOCUMENT_TYPES = ["SURFSENSE_DOCS"] as const;
@ -47,20 +46,14 @@ export function RowActions({
document.document_type as (typeof EDITABLE_DOCUMENT_TYPES)[number]
);
// Documents in "pending" or "processing" state should show disabled delete
const isBeingProcessed =
document.status?.state === "pending" || document.status?.state === "processing";
// FILE documents that failed processing cannot be edited
const isFileFailed = document.document_type === "FILE" && document.status?.state === "failed";
// SURFSENSE_DOCS are system-managed and should not show delete at all
const shouldShowDelete = !NON_DELETABLE_DOCUMENT_TYPES.includes(
document.document_type as (typeof NON_DELETABLE_DOCUMENT_TYPES)[number]
);
// Edit is disabled while processing OR for failed FILE documents
const isEditDisabled = isBeingProcessed || isFileFailed;
const isEditDisabled = isBeingProcessed;
const isDeleteDisabled = isBeingProcessed;
const handleDelete = async () => {

View file

@ -0,0 +1,133 @@
"use client";
import { motion } from "motion/react";
import { Skeleton } from "@/components/ui/skeleton";
export default function Loading() {
return (
<motion.div
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ duration: 0.3 }}
className="w-full px-6 py-4 space-y-6 min-h-[calc(100vh-64px)]"
>
{/* Summary Dashboard Skeleton */}
<motion.div
className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
>
{[...Array(4)].map((_, i) => (
<div key={i} className="rounded-lg border p-4">
<div className="flex flex-row items-center justify-between space-y-0 pb-2">
<Skeleton className="h-4 w-24" />
<Skeleton className="h-4 w-4 rounded-full" />
</div>
<div className="space-y-2">
<Skeleton className="h-8 w-16" />
<Skeleton className="h-3 w-32" />
</div>
</div>
))}
</motion.div>
{/* Header Section Skeleton */}
<motion.div
className="flex items-center justify-between"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.1 }}
>
<div className="space-y-2">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-4 w-64" />
</div>
<Skeleton className="h-9 w-24" />
</motion.div>
{/* Filters Skeleton */}
<motion.div
className="flex flex-wrap items-center justify-start gap-3 w-full"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.2 }}
>
<div className="flex items-center gap-3 flex-wrap w-full sm:w-auto">
<Skeleton className="h-9 w-full sm:w-60" />
<Skeleton className="h-9 w-24" />
<Skeleton className="h-9 w-24" />
<Skeleton className="h-9 w-20" />
</div>
</motion.div>
{/* Table Skeleton */}
<motion.div
className="rounded-md border overflow-hidden"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.3 }}
>
{/* Table Header */}
<div className="border-b bg-muted/50 px-4 py-3 flex items-center gap-4">
<Skeleton className="h-4 w-4" />
<Skeleton className="h-4 w-16" />
<Skeleton className="h-4 w-20" />
<Skeleton className="h-4 w-24" />
<Skeleton className="h-4 flex-1" />
<Skeleton className="h-4 w-24" />
<Skeleton className="h-4 w-8" />
</div>
{/* Table Rows */}
{[...Array(6)].map((_, i) => (
<div key={i} className="border-b px-4 py-3 flex items-center gap-4 hover:bg-muted/50">
<Skeleton className="h-4 w-4" />
<Skeleton className="h-6 w-12 rounded-full" />
<Skeleton className="h-6 w-16 rounded-full" />
<div className="flex items-center gap-2">
<Skeleton className="h-4 w-4" />
<Skeleton className="h-4 w-20" />
</div>
<div className="flex-1 space-y-1">
<Skeleton className="h-4 w-32" />
<Skeleton className="h-3 w-48" />
</div>
<div className="space-y-1">
<Skeleton className="h-3 w-24" />
<Skeleton className="h-3 w-20" />
</div>
<Skeleton className="h-8 w-8" />
</div>
))}
</motion.div>
{/* Pagination Skeleton */}
<div className="flex items-center justify-between gap-8 mt-4">
<motion.div
className="flex items-center gap-3"
initial={{ opacity: 0, x: -20 }}
animate={{ opacity: 1, x: 0 }}
>
<Skeleton className="h-4 w-20 max-sm:sr-only" />
<Skeleton className="h-9 w-16" />
</motion.div>
<motion.div
className="flex grow justify-end"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2 }}
>
<Skeleton className="h-4 w-40" />
</motion.div>
<div className="flex items-center gap-2">
<Skeleton className="h-9 w-9" />
<Skeleton className="h-9 w-9" />
<Skeleton className="h-9 w-9" />
<Skeleton className="h-9 w-9" />
</div>
</div>
</motion.div>
);
}

View file

@ -0,0 +1,10 @@
import { Skeleton } from "@/components/ui/skeleton";
export default function Loading() {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-4 p-4">
<Skeleton className="h-4 w-64" />
<Skeleton className="h-32 w-full max-w-2xl rounded-xl" />
</div>
);
}

View file

@ -30,8 +30,10 @@ import {
// extractWriteTodosFromContent,
} from "@/atoms/chat/plan-state.atom";
import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { membersAtom } from "@/atoms/members/members-query.atoms";
import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
import { Thread } from "@/components/assistant-ui/thread";
@ -74,6 +76,7 @@ import {
trackChatMessageSent,
trackChatResponseReceived,
} from "@/lib/posthog/events";
import Loading from "../loading";
/**
* After a tool produces output, mark any previously-decided interrupt tool
@ -188,6 +191,8 @@ export default function NewChatPage() {
const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom);
const closeReportPanel = useSetAtom(closeReportPanelAtom);
const closeEditorPanel = useSetAtom(closeEditorPanelAtom);
const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom);
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
@ -726,12 +731,10 @@ export default function NewChatPage() {
}
case "data-thread-title-update": {
// Handle thread title update from LLM-generated title
const titleData = parsed.data as { threadId: number; title: string };
if (titleData?.title && titleData?.threadId === currentThreadId) {
// Update current thread state with new title
setCurrentThread((prev) => (prev ? { ...prev, title: titleData.title } : prev));
// Invalidate thread list to refresh sidebar
updateChatTabTitle({ chatId: currentThreadId, title: titleData.title });
queryClient.invalidateQueries({
queryKey: ["threads", String(searchSpaceId)],
});
@ -739,6 +742,20 @@ export default function NewChatPage() {
break;
}
case "data-documents-updated": {
const docEvent = parsed.data as {
action: string;
document: AgentCreatedDocument;
};
if (docEvent?.document?.id) {
setAgentCreatedDocuments((prev) => {
if (prev.some((d) => d.id === docEvent.document.id)) return prev;
return [...prev, docEvent.document];
});
}
break;
}
case "data-interrupt-request": {
wasInterrupted = true;
const interruptData = parsed.data as Record<string, unknown>;
@ -1526,49 +1543,14 @@ export default function NewChatPage() {
// Show loading state only when loading an existing thread
if (isInitializing) {
return (
<div className="flex h-[calc(100dvh-64px)] flex-col bg-main-panel px-4">
<div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
{/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-56 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[85%]" />
<Skeleton className="h-4 w-[70%]" />
</div>
{/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-40 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[90%]" />
<Skeleton className="h-4 w-[60%]" />
</div>
</div>
{/* Input bar */}
<div className="sticky bottom-0 pb-6 bg-main-panel">
<div className="mx-auto w-full max-w-[44rem]">
<Skeleton className="h-24 w-full rounded-2xl" />
</div>
</div>
</div>
);
return <Loading />;
}
// Show error state only if we tried to load an existing thread but failed
// For new chats (urlChatId === 0), threadId being null is expected (lazy creation)
if (!threadId && urlChatId > 0) {
return (
<div className="flex h-[calc(100dvh-64px)] flex-col items-center justify-center gap-4">
<div className="flex h-full flex-col items-center justify-center gap-4">
<div className="text-destructive">Failed to load chat</div>
<button
type="button"
@ -1587,7 +1569,7 @@ export default function NewChatPage() {
return (
<AssistantRuntimeProvider runtime={runtime}>
<ThinkingStepsDataUI />
<div key={searchSpaceId} className="flex h-[calc(100dvh-64px)] overflow-hidden">
<div key={searchSpaceId} className="flex h-full overflow-hidden">
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
<Thread />
</div>

View file

@ -0,0 +1,45 @@
import { Skeleton } from "@/components/ui/skeleton";
export default function Loading() {
return (
<div className="flex h-full flex-col bg-main-panel px-4">
<div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
{/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-56 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[85%]" />
<Skeleton className="h-18 w-[40%]" />
</div>
{/* User message */}
<div className="flex gap-2 justify-end">
<Skeleton className="h-12 w-72 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-10 w-[30%]" />
<Skeleton className="h-4 w-[90%]" />
<Skeleton className="h-6 w-[60%]" />
</div>
{/* User message */}
<div className="flex gap-2 justify-end">
<Skeleton className="h-12 w-96 rounded-2xl" />
</div>
</div>
{/* Input bar */}
<div className="sticky bottom-0 pb-6 bg-main-panel">
<div className="mx-auto w-full max-w-[44rem]">
<Skeleton className="h-24 w-full rounded-2xl" />
</div>
</div>
</div>
);
}

View file

@ -1,15 +1,10 @@
"use client";
import { redirect } from "next/navigation";
import { useParams, useRouter } from "next/navigation";
import { useEffect } from "react";
export default function SearchSpaceDashboardPage() {
const router = useRouter();
const { search_space_id } = useParams();
useEffect(() => {
router.push(`/dashboard/${search_space_id}/new-chat`);
}, [router, search_space_id]);
return <></>;
export default async function SearchSpaceDashboardPage({
params,
}: {
params: Promise<{ search_space_id: string }>;
}) {
const { search_space_id } = await params;
redirect(`/dashboard/${search_space_id}/new-chat`);
}

View file

@ -188,13 +188,13 @@ export function TeamContent({ searchSpaceId }: TeamContentProps) {
[deleteMember, searchSpaceId]
);
const { data: roles = [] } = useQuery({
const { data: roles = [], isLoading: rolesLoading } = useQuery({
queryKey: cacheKeys.roles.all(searchSpaceId.toString()),
queryFn: () => rolesApiService.getRoles({ search_space_id: searchSpaceId }),
enabled: !!searchSpaceId,
});
const { data: invites = [] } = useQuery({
const { data: invites = [], isLoading: invitesLoading } = useQuery({
queryKey: cacheKeys.invites.all(searchSpaceId.toString()),
queryFn: () => invitesApiService.getInvites({ search_space_id: searchSpaceId }),
staleTime: 5 * 60 * 1000,
@ -294,15 +294,23 @@ export function TeamContent({ searchSpaceId }: TeamContentProps) {
return (
<div className="space-y-4 md:space-y-6">
<div className="flex items-center gap-2 flex-wrap">
{canInvite && (
<CreateInviteDialog
roles={roles}
onCreateInvite={handleCreateInvite}
searchSpaceId={searchSpaceId}
/>
{rolesLoading ? (
<Skeleton className="h-9 w-32 rounded-md" />
) : (
canInvite && (
<CreateInviteDialog
roles={roles}
onCreateInvite={handleCreateInvite}
searchSpaceId={searchSpaceId}
/>
)
)}
{canInvite && activeInvites.length > 0 && (
<AllInvitesDialog invites={activeInvites} onRevokeInvite={handleRevokeInvite} />
{invitesLoading ? (
<Skeleton className="h-9 w-32 rounded-md" />
) : (
canInvite && activeInvites.length > 0 && (
<AllInvitesDialog invites={activeInvites} onRevokeInvite={handleRevokeInvite} />
)
)}
<p className="text-xs md:text-sm text-muted-foreground whitespace-nowrap">
{members.length} {members.length === 1 ? "member" : "members"}
@ -595,6 +603,7 @@ function CreateInviteDialog({
});
} catch (error) {
console.error("Failed to create invite:", error);
toast.error("Failed to create invite. Please try again.");
} finally {
setCreating(false);
}

View file

@ -3,11 +3,11 @@
import posthog from "posthog-js";
import { useEffect } from "react";
export default function Error({
export default function ErrorPage({
error,
reset,
}: {
error: Error & { digest?: string };
error: globalThis.Error & { digest?: string };
reset: () => void;
}) {
useEffect(() => {

View file

@ -1,8 +1,9 @@
"use client";
import NextError from "next/error";
import "./globals.css";
import posthog from "posthog-js";
import { useEffect } from "react";
import { Button } from "@/components/ui/button";
export default function GlobalError({
error,
@ -18,10 +19,11 @@ export default function GlobalError({
return (
<html lang="en">
<body>
<NextError statusCode={0} />
<button type="button" onClick={reset}>
Try again
</button>
<div className="flex min-h-screen flex-col items-center justify-center gap-4 p-4">
<h2 className="text-xl font-semibold">Something went wrong</h2>
<p className="text-sm text-muted-foreground">An unexpected error occurred.</p>
<Button onClick={reset}>Try again</Button>
</div>
</body>
</html>
);

Some files were not shown because too many files have changed in this diff Show more