Merge pull request #840 from MODSetter/dev

feat: Google Drive HITL tools, team management UI, Daytona sandboxes, indexing pipeline hardening & test infrastructure
Rohan Verma 2026-02-26 20:49:50 -08:00 committed by GitHub
commit 3ca401cb2c
184 changed files with 17831 additions and 9340 deletions

.cursor/skills/tdd/SKILL.md

@ -0,0 +1,112 @@
---
name: tdd
description: Strict Python TDD workflow using pytest (Red-Green-Refactor).
---
---
name: tdd
description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
---
# Test-Driven Development
## Philosophy
**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.
**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.
**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.
See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.
## Anti-Pattern: Horizontal Slices
**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."
This produces **crap tests**:
- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
- You outrun your headlights, committing to test structure before understanding the implementation
**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.
```
WRONG (horizontal):
RED: test1, test2, test3, test4, test5
GREEN: impl1, impl2, impl3, impl4, impl5
RIGHT (vertical):
RED→GREEN: test1→impl1
RED→GREEN: test2→impl2
RED→GREEN: test3→impl3
...
```
## Workflow
### 1. Planning
Before writing any code:
- [ ] Confirm with user what interface changes are needed
- [ ] Confirm with user which behaviors to test (prioritize)
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
- [ ] Design interfaces for [testability](interface-design.md)
- [ ] List the behaviors to test (not implementation steps)
- [ ] Get user approval on the plan
Ask: "What should the public interface look like? Which behaviors are most important to test?"
**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.
### 2. Tracer Bullet
Write ONE test that confirms ONE thing about the system:
```
RED: Write test for first behavior → test fails
GREEN: Write minimal code to pass → test passes
```
This is your tracer bullet - proves the path works end-to-end.
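A minimal sketch of one RED→GREEN cycle, assuming a hypothetical `Cart` class (not part of this repository). The test was written first against the public interface; the class holds only enough code to make it pass.

```python
# Hypothetical example: `Cart` and its methods are illustrative names only.


# GREEN: the minimal implementation, written after the test below had failed.
class Cart:
    def __init__(self) -> None:
        self._items: list[tuple[str, float]] = []

    def add(self, name: str, price: float) -> None:
        self._items.append((name, price))

    def total(self) -> float:
        return sum(price for _, price in self._items)


# RED: written first; it describes one observable behavior of the cart.
def test_cart_total_reflects_added_items() -> None:
    cart = Cart()
    cart.add("book", 12.5)
    cart.add("pen", 2.5)
    assert cart.total() == 15.0
```

The test never touches `_items`, so the internal storage could change to a dict or a database without breaking it.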
### 3. Incremental Loop
For each remaining behavior:
```
RED: Write next test → fails
GREEN: Minimal code to pass → passes
```
Rules:
- One test at a time
- Only enough code to pass current test
- Don't anticipate future tests
- Keep tests focused on observable behavior
### 4. Refactor
After all tests pass, look for [refactor candidates](refactoring.md):
- [ ] Extract duplication
- [ ] Deepen modules (move complexity behind simple interfaces)
- [ ] Apply SOLID principles where natural
- [ ] Consider what new code reveals about existing code
- [ ] Run tests after each refactor step
**Never refactor while RED.** Get to GREEN first.
## Checklist Per Cycle
```
[ ] Test describes behavior, not implementation
[ ] Test uses public interface only
[ ] Test would survive internal refactor
[ ] Code is minimal for this test
[ ] No speculative features added
```

@ -0,0 +1,33 @@
# Deep Modules
From "A Philosophy of Software Design":
**Deep module** = small interface + lots of implementation
```
┌─────────────────────┐
│ Small Interface │ ← Few methods, simple params
├─────────────────────┤
│ │
│ │
│ Deep Implementation│ ← Complex logic hidden
│ │
│ │
└─────────────────────┘
```
**Shallow module** = large interface + little implementation (avoid)
```
┌─────────────────────────────────┐
│ Large Interface │ ← Many methods, complex params
├─────────────────────────────────┤
│ Thin Implementation │ ← Just passes through
└─────────────────────────────────┘
```
When designing interfaces, ask:
- Can I reduce the number of methods?
- Can I simplify the parameters?
- Can I hide more complexity inside?
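As a rough illustration of a deep module, the sketch below hides defaults, missing-file handling, and replace-on-write behind two methods. `SettingsStore` is a hypothetical name chosen for this example; it does not come from the book or from this repository.

```python
# Hypothetical example of a deep module: small interface, more work inside.
from __future__ import annotations

import json
from pathlib import Path


class SettingsStore:
    """Two public methods hide defaults, missing files, and safe writes."""

    def __init__(self, path: Path, defaults: dict[str, object] | None = None) -> None:
        self._path = path
        self._defaults = dict(defaults or {})

    def load(self) -> dict[str, object]:
        # A missing or empty file falls back to defaults instead of raising.
        if not self._path.exists() or self._path.stat().st_size == 0:
            return dict(self._defaults)
        stored = json.loads(self._path.read_text(encoding="utf-8"))
        return {**self._defaults, **stored}

    def save(self, settings: dict[str, object]) -> None:
        # Write to a temporary file, then replace, so readers never see a partial file.
        tmp = self._path.with_suffix(self._path.suffix + ".tmp")
        tmp.write_text(json.dumps(settings, sort_keys=True), encoding="utf-8")
        tmp.replace(self._path)
```

Callers only ever see `load()` and `save()`, so tests can exercise the whole module through those two calls.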

@ -0,0 +1,33 @@
# Interface Design for Testability
Good interfaces make testing natural:
1. **Accept dependencies, don't create them**
```python
# Testable
def process_order(order, payment_gateway):
    pass

# Hard to test
def process_order(order):
    gateway = StripeGateway()
```
2. **Return results, don't produce side effects**
```python
# Testable
def calculate_discount(cart) -> float:
    return discount

# Hard to test
def apply_discount(cart) -> None:
    cart.total -= discount
```
3. **Small surface area**
* Fewer methods = fewer tests needed
* Fewer params = simpler test setup

@ -0,0 +1,69 @@
# When to Mock
Mock at **system boundaries** only:
* External APIs (payment, email, etc.)
* Databases (sometimes - prefer test DB)
* Time/randomness
* File system (sometimes)
Don't mock:
* Your own classes/modules
* Internal collaborators
* Anything you control
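For the time boundary, one option is to pass the clock in as a parameter so that a test can pin it without patching anything. This is a minimal sketch; `is_token_expired` and its `now` parameter are hypothetical names used only for illustration.

```python
# Hypothetical example: injecting the clock instead of patching `datetime`.
from collections.abc import Callable
from datetime import datetime, timedelta, timezone


def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def is_token_expired(
    issued_at: datetime,
    ttl: timedelta,
    now: Callable[[], datetime] = _utc_now,
) -> bool:
    """Return True once at least `ttl` has elapsed since `issued_at`."""
    return now() - issued_at >= ttl


def test_token_expires_after_ttl() -> None:
    issued = datetime(2026, 1, 1, tzinfo=timezone.utc)
    # The test controls time by supplying a fixed clock.
    fixed_clock = lambda: datetime(2026, 1, 1, 0, 30, tzinfo=timezone.utc)
    assert is_token_expired(issued, timedelta(minutes=15), now=fixed_clock)
    assert not is_token_expired(issued, timedelta(hours=1), now=fixed_clock)
```

Production callers omit `now` and get the real clock; only the boundary is replaced in tests.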
## Designing for Mockability
At system boundaries, design interfaces that are easy to mock:
**1. Use dependency injection**
Pass external dependencies in rather than creating them internally:
```python
import os

# Easy to mock
def process_payment(order, payment_client):
    return payment_client.charge(order.total)

# Hard to mock
def process_payment(order):
    client = StripeClient(os.getenv("STRIPE_KEY"))
    return client.charge(order.total)
```
**2. Prefer SDK-style interfaces over generic fetchers**
Create specific functions for each external operation instead of one generic function with conditional logic:
```python
import requests

# GOOD: Each function is independently mockable
class UserAPI:
    def get_user(self, user_id):
        return requests.get(f"/users/{user_id}")

    def get_orders(self, user_id):
        return requests.get(f"/users/{user_id}/orders")

    def create_order(self, data):
        return requests.post("/orders", json=data)

# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
    def fetch(self, endpoint, method="GET", data=None):
        return requests.request(method, endpoint, json=data)
```
The SDK approach means:
* Each mock returns one specific shape
* No conditional logic in test setup
* Easier to see which endpoints a test exercises
* Type safety per endpoint
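To show what "one specific shape per mock" can look like in a test, here is a sketch with a hand-written fake. `FakeUserAPI` and `order_count` are hypothetical names, and the sketch assumes the SDK-style methods return already-parsed data rather than raw HTTP responses.

```python
# Hypothetical sketch: FakeUserAPI and order_count are illustrative names.
# Assumes SDK-style methods return parsed data, not raw HTTP responses.


class FakeUserAPI:
    """Test double with the same method names as an SDK-style UserAPI."""

    def __init__(self, orders_by_user: dict[int, list[dict]]) -> None:
        self._orders_by_user = orders_by_user

    def get_orders(self, user_id: int) -> list[dict]:
        # One method, one return shape: no branching on endpoint or HTTP verb.
        return self._orders_by_user.get(user_id, [])


def order_count(api, user_id: int) -> int:
    """Code under test: depends only on the boundary method it needs."""
    return len(api.get_orders(user_id))


def test_order_count_reports_number_of_orders() -> None:
    api = FakeUserAPI({7: [{"id": 1}, {"id": 2}]})
    assert order_count(api, 7) == 2
    assert order_count(api, 8) == 0
```

With a generic `fetch(endpoint, method, data)` boundary, the same fake would need to inspect its arguments to decide what to return.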

@ -0,0 +1,10 @@
# Refactor Candidates
After TDD cycle, look for:
- **Duplication** → Extract function/class
- **Long methods** → Break into private helpers (keep tests on public interface)
- **Shallow modules** → Combine or deepen
- **Feature envy** → Move logic to where data lives
- **Primitive obsession** → Introduce value objects
- **Existing code** the new code reveals as problematic
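The "primitive obsession → value objects" item could be sketched as follows. `Money` is a hypothetical example type, not something defined in this repository.

```python
# Hypothetical example: replacing a bare (int, str) pair with a value object.
from dataclasses import dataclass


@dataclass(frozen=True)
class Money:
    amount_cents: int
    currency: str

    def __post_init__(self) -> None:
        # Validation lives in one place instead of at every call site.
        if self.amount_cents < 0:
            raise ValueError("amount_cents must be non-negative")

    def __add__(self, other: "Money") -> "Money":
        if other.currency != self.currency:
            raise ValueError("cannot add amounts in different currencies")
        return Money(self.amount_cents + other.amount_cents, self.currency)
```

Because the dataclass is frozen and compares by value, tests can assert on whole `Money` objects through the public interface.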

@ -0,0 +1,60 @@
# Good and Bad Tests
## Good Tests
**Integration-style**: Test through real interfaces, not mocks of internal parts.
```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
    cart = create_cart()
    cart.add(product)
    result = checkout(cart, payment_method)
    assert result.status == "confirmed"
```
Characteristics:
* Tests behavior users/callers care about
* Uses public API only
* Survives internal refactors
* Describes WHAT, not HOW
* One logical assertion per test
## Bad Tests
**Implementation-detail tests**: Coupled to internal structure.
```python
from unittest.mock import MagicMock

# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
    mock_payment = MagicMock()
    checkout(cart, mock_payment)
    mock_payment.process.assert_called_with(cart.total)
```
Red flags:
* Mocking internal collaborators
* Testing private methods
* Asserting on call counts/order
* Test breaks when refactoring without behavior change
* Test name describes HOW not WHAT
* Verifying through external means instead of interface
```python
# BAD: Bypasses interface to verify
def test_create_user_saves_to_database():
    create_user({"name": "Alice"})
    row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
    assert row is not None

# GOOD: Verifies through interface
def test_create_user_makes_user_retrievable():
    user = create_user({"name": "Alice"})
    retrieved = get_user(user.id)
    assert retrieved.name == "Alice"
```

@ -260,6 +260,10 @@ ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
ENV NEXT_PUBLIC_ETL_SERVICE=DOCLING
# Daytona Sandbox (cloud code execution — no local server needed)
ENV DAYTONA_SANDBOX_ENABLED=FALSE
# DAYTONA_API_KEY, DAYTONA_API_URL, DAYTONA_TARGET: set at runtime for production.
# Electric SQL configuration (ELECTRIC_DATABASE_URL is built dynamically by entrypoint from these values)
ENV ELECTRIC_DB_USER=electric
ENV ELECTRIC_DB_PASSWORD=electric_password

@ -65,6 +65,11 @@ services:
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
# Daytona Sandbox: uncomment and set credentials to enable cloud code execution
# - DAYTONA_SANDBOX_ENABLED=TRUE
# - DAYTONA_API_KEY=${DAYTONA_API_KEY:-}
# - DAYTONA_API_URL=${DAYTONA_API_URL:-https://app.daytona.io/api}
# - DAYTONA_TARGET=${DAYTONA_TARGET:-us}
depends_on:
- db
- redis

@ -232,6 +232,7 @@ echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
echo " TTS Service: ${TTS_SERVICE}"
echo " STT Service: ${STT_SERVICE}"
echo " Daytona Sandbox: ${DAYTONA_SANDBOX_ENABLED:-FALSE}"
echo "==========================================="
echo ""

@ -167,37 +167,12 @@ LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense
# Uvicorn Server Configuration
# Full documentation for Uvicorn options can be found at: https://www.uvicorn.org/#command-line-options
UVICORN_HOST="0.0.0.0"
UVICORN_PORT=8000
UVICORN_LOG_LEVEL=info
# OPTIONAL: Advanced Uvicorn Options (uncomment to use)
# UVICORN_PROXY_HEADERS=false
# UVICORN_FORWARDED_ALLOW_IPS="127.0.0.1"
# UVICORN_WORKERS=1
# UVICORN_ACCESS_LOG=true
# UVICORN_LOOP="auto"
# UVICORN_HTTP="auto"
# UVICORN_WS="auto"
# UVICORN_LIFESPAN="auto"
# UVICORN_LOG_CONFIG=""
# UVICORN_SERVER_HEADER=true
# UVICORN_DATE_HEADER=true
# UVICORN_LIMIT_CONCURRENCY=
# UVICORN_LIMIT_MAX_REQUESTS=
# UVICORN_TIMEOUT_KEEP_ALIVE=5
# UVICORN_TIMEOUT_NOTIFY=30
# UVICORN_SSL_KEYFILE=""
# UVICORN_SSL_CERTFILE=""
# UVICORN_SSL_KEYFILE_PASSWORD=""
# UVICORN_SSL_VERSION=""
# UVICORN_SSL_CERT_REQS=""
# UVICORN_SSL_CA_CERTS=""
# UVICORN_SSL_CIPHERS=""
# UVICORN_HEADERS=""
# UVICORN_USE_COLORS=true
# UVICORN_UDS=""
# UVICORN_FD=""
# UVICORN_ROOT_PATH=""
# Agent Specific Configuration
# Daytona Sandbox (secure cloud code execution for deep agent)
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
DAYTONA_SANDBOX_ENABLED=TRUE
DAYTONA_API_KEY=dtn_asdasfasfafas
DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files

@ -6,6 +6,7 @@ __pycache__/
.flashrank_cache
surf_new_backend.egg-info/
podcasts/
sandbox_files/
temp_audio/
celerybeat-schedule*
celerybeat-schedule.*

@ -0,0 +1,46 @@
"""102_add_enable_summary_to_connectors
Revision ID: 102
Revises: 101
Create Date: 2026-02-26
Adds enable_summary boolean column to search_source_connectors.
Defaults to False for all existing and new connectors so LLM-based
summary generation is opt-in.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "102"
down_revision: str | None = "101"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    conn = op.get_bind()
    existing_columns = [
        col["name"] for col in sa.inspect(conn).get_columns("search_source_connectors")
    ]
    if "enable_summary" not in existing_columns:
        op.add_column(
            "search_source_connectors",
            sa.Column(
                "enable_summary",
                sa.Boolean(),
                nullable=False,
                server_default=sa.text("false"),
            ),
        )


def downgrade() -> None:
    op.drop_column("search_source_connectors", "enable_summary")
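The "check, then add" pattern used in `upgrade()` can be tried in isolation. The sketch below uses the standard-library `sqlite3` module as a stand-in for Alembic and PostgreSQL, so the SQL dialect differs from the real migration. `add_column_if_missing` is a hypothetical helper name.

```python
# Illustrative stand-in only: stdlib sqlite3 instead of Alembic/PostgreSQL.
import sqlite3


def add_column_if_missing(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> bool:
    """Add `column` to `table` unless it already exists. Returns True if added."""
    # PRAGMA table_info yields one row per column; index 1 is the column name.
    existing = [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]
    if column in existing:
        return False
    conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {ddl}")
    return True


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE search_source_connectors (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO search_source_connectors (name) VALUES ('drive')")

# The first call adds the column; the second call is a no-op.
first = add_column_if_missing(
    conn, "search_source_connectors", "enable_summary", "BOOLEAN NOT NULL DEFAULT 0"
)
second = add_column_if_missing(
    conn, "search_source_connectors", "enable_summary", "BOOLEAN NOT NULL DEFAULT 0"
)
```

Rows that existed before the change pick up the default, which mirrors the opt-in behavior described in the migration docstring.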

@ -6,10 +6,14 @@ with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
"""
import asyncio
import logging
import time
from collections.abc import Sequence
from typing import Any
from deepagents import create_deep_agent
from deepagents.backends.protocol import SandboxBackendProtocol
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -25,6 +29,8 @@ from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
_perf_log = logging.getLogger("surfsense.perf")
# =============================================================================
# Connector Type Mapping
# =============================================================================
@ -128,6 +134,7 @@ async def create_surfsense_deep_agent(
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_backend: SandboxBackendProtocol | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -167,6 +174,9 @@ async def create_surfsense_deep_agent(
These are always added regardless of enabled/disabled settings.
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
Falls back to Chromium/Trafilatura if not provided.
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
secure code execution. When provided, the agent gets an
isolated ``execute`` tool for running shell commands.
Returns:
CompiledStateGraph: The configured deep agent
@ -205,32 +215,41 @@ async def create_surfsense_deep_agent(
additional_tools=[my_custom_tool]
)
"""
_t_agent_total = time.perf_counter()
# Discover available connectors and document types for this search space
# This enables dynamic tool docstrings that inform the LLM about what's actually available
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
# Get enabled search source connectors for this search space
connector_types = await connector_service.get_available_connectors(
search_space_id
)
if connector_types:
# Convert enum values to strings and also include mapped document types
available_connectors = _map_connectors_to_searchable_types(connector_types)
# Get document types that have at least one document indexed
available_document_types = await connector_service.get_available_document_types(
search_space_id
)
except Exception as e:
# Log but don't fail - fall back to all connectors if discovery fails
import logging
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
time.perf_counter() - _t0,
)
# Build dependencies dict for the tools registry
visibility = thread_visibility or ChatVisibility.PRIVATE
# Extract the model's context window so tools can size their output.
_model_profile = getattr(llm, "profile", None)
_max_input_tokens: int | None = (
_model_profile.get("max_input_tokens")
if isinstance(_model_profile, dict)
else None
)
dependencies = {
"search_space_id": search_space_id,
"db_session": db_session,
@ -241,6 +260,7 @@ async def create_surfsense_deep_agent(
"thread_visibility": visibility,
"available_connectors": available_connectors,
"available_document_types": available_document_types,
"max_input_tokens": _max_input_tokens,
}
# Disable Notion action tools if no Notion connector is configured
@ -269,35 +289,61 @@ async def create_surfsense_deep_agent(
modified_disabled_tools.extend(linear_tools)
# Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter()
tools = await build_tools_async(
dependencies=dependencies,
enabled_tools=enabled_tools,
disabled_tools=modified_disabled_tools,
additional_tools=list(additional_tools) if additional_tools else None,
)
_perf_log.info(
"[create_agent] build_tools_async in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Build system prompt based on agent_config
_t0 = time.perf_counter()
_sandbox_enabled = sandbox_backend is not None
if agent_config is not None:
# Use configurable prompt with settings from NewLLMConfig
system_prompt = build_configurable_system_prompt(
custom_system_instructions=agent_config.system_instructions,
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
# Create the deep agent with system prompt and checkpointer
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
agent = create_deep_agent(
# Build optional kwargs for the deep agent
deep_agent_kwargs: dict[str, Any] = {}
if sandbox_backend is not None:
deep_agent_kwargs["backend"] = sandbox_backend
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
create_deep_agent,
model=llm,
tools=tools,
system_prompt=system_prompt,
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer,
**deep_agent_kwargs,
)
_perf_log.info(
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
time.perf_counter() - _t0,
)
_perf_log.info(
"[create_agent] Total agent creation in %.3fs",
time.perf_counter() - _t_agent_total,
)
return agent

@ -0,0 +1,275 @@
"""
Daytona sandbox provider for SurfSense deep agent.
Manages the lifecycle of sandboxed code execution environments.
Each conversation thread gets its own isolated sandbox instance
via the Daytona cloud API, identified by labels.
Files created during a session are persisted to local storage before
the sandbox is deleted so they remain downloadable after cleanup.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import shutil
from pathlib import Path
from daytona import (
CreateSandboxFromSnapshotParams,
Daytona,
DaytonaConfig,
SandboxState,
)
from daytona.common.errors import DaytonaError
from deepagents.backends.protocol import ExecuteResponse
from langchain_daytona import DaytonaSandbox
logger = logging.getLogger(__name__)
class _TimeoutAwareSandbox(DaytonaSandbox):
    """DaytonaSandbox subclass that accepts the per-command *timeout*
    kwarg required by the deepagents middleware.

    The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
    so deepagents raises *"This sandbox backend does not support
    per-command timeout overrides"* on every first call. This thin
    wrapper forwards the parameter to the Daytona SDK.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        t = timeout if timeout is not None else self._timeout
        result = self._sandbox.process.exec(command, timeout=t)
        return ExecuteResponse(
            output=result.result,
            exit_code=result.exit_code,
            truncated=False,
        )

    async def aexecute(
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:  # type: ignore[override]
        return await asyncio.to_thread(self.execute, command, timeout=timeout)
_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
THREAD_LABEL_KEY = "surfsense_thread"


def is_sandbox_enabled() -> bool:
    return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"


def _get_client() -> Daytona:
    global _daytona_client
    if _daytona_client is None:
        config = DaytonaConfig(
            api_key=os.environ.get("DAYTONA_API_KEY", ""),
            api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
            target=os.environ.get("DAYTONA_TARGET", "us"),
        )
        _daytona_client = Daytona(config)
    return _daytona_client
def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
    """Find an existing sandbox for *thread_id*, or create a new one.

    If an existing sandbox is found but is stopped/archived, it will be
    restarted automatically before returning.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}
    try:
        sandbox = client.find_one(labels=labels)
        logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)
        if sandbox.state in (
            SandboxState.STOPPED,
            SandboxState.STOPPING,
            SandboxState.ARCHIVED,
        ):
            logger.info("Starting stopped sandbox %s", sandbox.id)
            sandbox.start(timeout=60)
            logger.info("Sandbox %s is now started", sandbox.id)
        elif sandbox.state in (
            SandboxState.ERROR,
            SandboxState.BUILD_FAILED,
            SandboxState.DESTROYED,
        ):
            logger.warning(
                "Sandbox %s in unrecoverable state %s — creating a new one",
                sandbox.id,
                sandbox.state,
            )
            sandbox = client.create(
                CreateSandboxFromSnapshotParams(language="python", labels=labels)
            )
            logger.info("Created replacement sandbox: %s", sandbox.id)
        elif sandbox.state != SandboxState.STARTED:
            sandbox.wait_for_sandbox_start(timeout=60)
    except Exception:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        sandbox = client.create(
            CreateSandboxFromSnapshotParams(language="python", labels=labels)
        )
        logger.info("Created new sandbox: %s", sandbox.id)
    return _TimeoutAwareSandbox(sandbox=sandbox)
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
    """Get or create a sandbox for a conversation thread.

    Uses an in-process cache keyed by thread_id so subsequent messages
    in the same conversation reuse the sandbox object without an API call.

    Args:
        thread_id: The conversation thread identifier.

    Returns:
        DaytonaSandbox connected to the sandbox.
    """
    key = str(thread_id)
    cached = _sandbox_cache.get(key)
    if cached is not None:
        logger.info("Reusing cached sandbox for thread %s", key)
        return cached
    sandbox = await asyncio.to_thread(_find_or_create, key)
    _sandbox_cache[key] = sandbox
    return sandbox
async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox for a conversation thread."""
    _sandbox_cache.pop(str(thread_id), None)

    def _delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
            return
        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox for thread %s",
                thread_id,
                exc_info=True,
            )

    await asyncio.to_thread(_delete)
# ---------------------------------------------------------------------------
# Local file persistence
# ---------------------------------------------------------------------------
def _get_sandbox_files_dir() -> Path:
    return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))


def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Map a sandbox-internal absolute path to a local filesystem path."""
    relative = sandbox_path.lstrip("/")
    return _get_sandbox_files_dir() / str(thread_id) / relative


def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
    """Read a previously-persisted sandbox file from local storage.

    Returns the file bytes, or *None* if the file does not exist locally.
    """
    local = _local_path_for(thread_id, sandbox_path)
    if local.is_file():
        return local.read_bytes()
    return None


def delete_local_sandbox_files(thread_id: int | str) -> None:
    """Remove all locally-persisted sandbox files for a thread."""
    thread_dir = _get_sandbox_files_dir() / str(thread_id)
    if thread_dir.is_dir():
        shutil.rmtree(thread_dir, ignore_errors=True)
        logger.info("Deleted local sandbox files for thread %s", thread_id)
async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted; freeing Daytona storage is the priority.
    """
    _sandbox_cache.pop(str(thread_id), None)

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        # Ensure the sandbox is running so we can download files
        if sandbox.state != SandboxState.STARTED:
            try:
                sandbox.start(timeout=60)
            except Exception:
                logger.warning(
                    "Could not start sandbox %s for file download — deleting anyway",
                    sandbox.id,
                    exc_info=True,
                )
                with contextlib.suppress(Exception):
                    client.delete(sandbox)
                return

        for path in sandbox_file_paths:
            try:
                content: bytes = sandbox.fs.download_file(path)
                local = _local_path_for(thread_id, path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s -> %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    await asyncio.to_thread(_persist_and_delete)

@ -645,6 +645,87 @@ However, from your video learning, it's important to note that asyncio is not su
</citation_instructions>
"""
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
SANDBOX_EXECUTION_INSTRUCTIONS = """
<code_execution>
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
## CRITICAL — CODE-FIRST RULE
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
- **Creating a chart, plot, graph, or visualization** → Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
- **Data analysis, statistics, or computation** → Write code to compute the answer. Do not do math by hand in text.
- **Generating or transforming files** (CSV, PDF, images, etc.) → Write code to create the file.
- **Running, testing, or debugging code** → Execute it in the sandbox.
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
Example (CORRECT):
User: "Create a pie chart of my benefits"
1. search_knowledge_base → retrieve benefits data
2. Immediately execute Python code (matplotlib) to generate the pie chart
3. Return the downloadable file + brief description
Example (WRONG):
User: "Create a pie chart of my benefits"
1. search_knowledge_base → retrieve benefits data
2. Print a text table with percentages and ask the user if they want a chart (NEVER do this)
## When to Use Code Execution
Use the sandbox when the task benefits from actually running code rather than just describing it:
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
- **Calculations**: Math, financial modeling, unit conversions, simulations
- **Code validation**: Run and test code snippets the user provides or asks about
- **File processing**: Parse, transform, or convert data files
- **Quick prototyping**: Demonstrate working code for the user's problem
- **Package exploration**: Install and test libraries the user is evaluating
## When NOT to Use Code Execution
Do not use the sandbox for:
- Answering factual questions from your own knowledge
- Summarizing or explaining concepts
- Simple formatting or text generation tasks
- Tasks that don't require running code to answer
## Package Management
- Use `pip install <package>` to install Python packages as needed
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
- Always verify a package installed successfully before using it
## Working Guidelines
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
- **Iterative approach**: For complex tasks, break work into steps: write code, run it, check output, refine
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
- **Be efficient**: Install packages once per session. Combine related commands when possible.
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
## Sharing Generated Files
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
## Data Analytics Best Practices
When the user asks you to analyze data:
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
2. Clean and validate before computing (handle nulls, check types)
3. Perform the analysis and present results clearly
4. Offer follow-up insights or visualizations when appropriate
</code_execution>
"""
# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
@ -670,6 +751,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -678,10 +760,12 @@ def build_surfsense_system_prompt(
- Default system instructions
- Tools instructions (always included)
- Citation instructions enabled
- Sandbox execution instructions (when sandbox_enabled=True)
Args:
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
Returns:
Complete system prompt string
@ -691,7 +775,13 @@ def build_surfsense_system_prompt(
system_instructions = _get_system_instructions(visibility, today)
tools_instructions = _get_tools_instructions(visibility)
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
return system_instructions + tools_instructions + citation_instructions
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
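The sandbox instructions appended here ask the model to print a `SANDBOX_FILE: <path>` line for every generated file. As a rough illustration of how a consumer could collect those paths from captured stdout, here is a minimal sketch; the helper name `extract_sandbox_files` and the regular expression are assumptions for illustration and are not taken from this diff.

```python
import re

# Hypothetical helper, not part of the code shown in this diff.
# Matches one marker per line and captures the path after the colon.
_SANDBOX_FILE_RE = re.compile(r"^SANDBOX_FILE:[ \t]*(\S.*?)[ \t]*$", re.MULTILINE)


def extract_sandbox_files(stdout: str) -> list[str]:
    """Return the paths announced via SANDBOX_FILE markers, in print order."""
    return _SANDBOX_FILE_RE.findall(stdout)


captured = "done\nSANDBOX_FILE: /tmp/report.csv\nSANDBOX_FILE: /tmp/chart.png\n"
print(extract_sandbox_files(captured))
```

Anchoring the pattern at the start of a line means a marker mentioned mid-sentence in other output is ignored, which is one reason the prompt asks for one marker per line.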
def build_configurable_system_prompt(
@ -700,14 +790,16 @@ def build_configurable_system_prompt(
citations_enabled: bool = True,
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
The prompt is composed of three parts:
The prompt is composed of up to four parts:
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
2. Tools Instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
4. Sandbox Execution Instructions - when sandbox_enabled=True
Args:
custom_system_instructions: Custom system instructions to use. If empty/None and
@ -719,6 +811,7 @@ def build_configurable_system_prompt(
anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
Returns:
Complete system prompt string
@ -727,7 +820,6 @@ def build_configurable_system_prompt(
# Determine system instructions
if custom_system_instructions and custom_system_instructions.strip():
# Use custom instructions, injecting the date placeholder if present
system_instructions = custom_system_instructions.format(
resolved_today=resolved_today
)
@ -735,7 +827,6 @@ def build_configurable_system_prompt(
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
else:
# No system instructions (edge case)
system_instructions = ""
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
@ -748,7 +839,14 @@ def build_configurable_system_prompt(
else SURFSENSE_NO_CITATION_INSTRUCTIONS
)
return system_instructions + tools_instructions + citation_instructions
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
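Both builders follow the same composition rule: fixed sections are concatenated in order, and an optional section collapses to an empty string when its flag is off. A minimal sketch of that rule, using short stand-in strings in place of the real prompt constants (the names below are illustrative only):

```python
# Stand-in section texts; the real constants are far longer.
SYSTEM = "<system/>\n"
TOOLS = "<tools/>\n"
CITATIONS = "<citations/>\n"
NO_CITATIONS = "<no_citations/>\n"
SANDBOX = "<code_execution/>\n"


def compose_prompt(*, citations_enabled: bool = True, sandbox_enabled: bool = False) -> str:
    """Concatenate sections in a fixed order; optional ones collapse to ''."""
    citation_part = CITATIONS if citations_enabled else NO_CITATIONS
    sandbox_part = SANDBOX if sandbox_enabled else ""
    return SYSTEM + TOOLS + citation_part + sandbox_part


print(compose_prompt(sandbox_enabled=True))
```

Because `sandbox_enabled` defaults to `False`, existing callers that do not pass the flag get the same prompt as before the change.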
def get_default_system_instructions() -> str:


@ -0,0 +1,11 @@
from app.agents.new_chat.tools.google_drive.create_file import (
create_create_google_drive_file_tool,
)
from app.agents.new_chat.tools.google_drive.trash_file import (
create_delete_google_drive_file_tool,
)
__all__ = [
"create_create_google_drive_file_tool",
"create_delete_google_drive_file_tool",
]


@ -0,0 +1,239 @@
import logging
from typing import Any, Literal
from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
_MIME_MAP: dict[str, str] = {
"google_doc": GOOGLE_DOC,
"google_sheet": GOOGLE_SHEET,
}
def create_create_google_drive_file_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def create_google_drive_file(
name: str,
file_type: Literal["google_doc", "google_sheet"],
content: str | None = None,
) -> dict[str, Any]:
"""Create a new Google Doc or Google Sheet in Google Drive.
Use this tool when the user explicitly asks to create a new document
or spreadsheet in Google Drive.
Args:
name: The file name (without extension).
file_type: Either "google_doc" or "google_sheet".
content: Optional initial content. For google_doc, provide markdown text.
For google_sheet, provide CSV-formatted text.
Returns:
Dictionary with:
- status: "success", "rejected", "insufficient_permissions", or "error"
- file_id: Google Drive file ID (if success)
- name: File name (if success)
- web_view_link: URL to open the file (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry the action.
Examples:
- "Create a Google Doc called 'Meeting Notes'"
- "Create a spreadsheet named 'Budget 2026' with some sample data"
"""
logger.info(
f"create_google_drive_file called: name='{name}', type='{file_type}'"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
}
if file_type not in _MIME_MAP:
return {
"status": "error",
"message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.",
}
try:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
logger.info(
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
)
approval = interrupt(
{
"type": "google_drive_file_creation",
"action": {
"tool": "create_google_drive_file",
"params": {
"name": name,
"file_type": file_type,
"content": content,
"connector_id": None,
"parent_folder_id": None,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
}
final_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_name = final_params.get("name", name)
final_file_type = final_params.get("file_type", file_type)
final_content = final_params.get("content", content)
final_connector_id = final_params.get("connector_id")
final_parent_folder_id = final_params.get("parent_folder_id")
if not final_name or not final_name.strip():
return {"status": "error", "message": "File name cannot be empty."}
mime_type = _MIME_MAP.get(final_file_type)
if not mime_type:
return {
"status": "error",
"message": f"Unsupported file type '{final_file_type}'.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
)
client = GoogleDriveClient(
session=db_session, connector_id=actual_connector_id
)
try:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
}
raise
logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
)
return {
"status": "success",
"file_id": created.get("id"),
"name": created.get("name"),
"web_view_link": created.get("webViewLink"),
"message": f"Successfully created '{created.get('name')}' in Google Drive.",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating Google Drive file: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the file. Please try again.",
}
return create_google_drive_file


@ -0,0 +1,243 @@
import logging
from typing import Any
from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_drive.client import GoogleDriveClient
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_google_drive_file_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
@tool
async def delete_google_drive_file(
file_name: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Move a Google Drive file to trash.
Use this tool when the user explicitly asks to delete, remove, or trash
a file in Google Drive.
Args:
file_name: The exact name of the file to trash (as it appears in Drive).
delete_from_kb: Whether to also remove the file from the knowledge base.
Default is False.
Set to True to remove from both Google Drive and knowledge base.
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", "insufficient_permissions", or "error"
- file_id: Google Drive file ID (if success)
- deleted_from_kb: whether the document was removed from the knowledge base
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined. Respond with a brief
acknowledgment and do NOT retry or suggest alternatives.
- If status is "not_found", relay the exact message to the user and ask them
to verify the file name or check if it has been indexed.
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
Inform the user they need to re-authenticate and do NOT retry this tool.
Examples:
- "Delete the 'Meeting Notes' file from Google Drive"
- "Trash the 'Old Budget' spreadsheet"
"""
logger.info(
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Drive tool not properly configured. Please contact support.",
}
try:
metadata_service = GoogleDriveToolMetadataService(db_session)
context = await metadata_service.get_trash_context(
search_space_id, user_id, file_name
)
if "error" in context:
error_msg = context["error"]
if "not found" in error_msg.lower():
logger.warning(f"File not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
logger.error(f"Failed to fetch trash context: {error_msg}")
return {"status": "error", "message": error_msg}
file = context["file"]
file_id = file["file_id"]
document_id = file.get("document_id")
connector_id_from_context = context["account"]["id"]
if not file_id:
return {
"status": "error",
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
}
logger.info(
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
)
approval = interrupt(
{
"type": "google_drive_file_trash",
"action": {
"tool": "delete_google_drive_file",
"params": {
"file_id": file_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
},
"context": context,
}
)
decisions_raw = (
approval.get("decisions", []) if isinstance(approval, dict) else []
)
decisions = (
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
)
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
logger.warning("No approval decision received")
return {"status": "error", "message": "No approval decision received"}
decision = decisions[0]
decision_type = decision.get("type") or decision.get("decision_type")
logger.info(f"User decision: {decision_type}")
if decision_type == "reject":
return {
"status": "rejected",
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
}
edited_action = decision.get("edited_action")
final_params: dict[str, Any] = {}
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
final_params = edited_args
elif isinstance(decision.get("args"), dict):
final_params = decision["args"]
final_file_id = final_params.get("file_id", file_id)
final_connector_id = final_params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
if not final_connector_id:
return {
"status": "error",
"message": "No connector found for this file.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Drive connector is invalid or has been disconnected.",
}
logger.info(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
)
client = GoogleDriveClient(session=db_session, connector_id=connector.id)
try:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
)
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
}
raise
logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
)
trash_result: dict[str, Any] = {
"status": "success",
"file_id": final_file_id,
"message": f"Successfully moved '{file['name']}' to trash.",
}
deleted_from_kb = False
if final_delete_from_kb and document_id:
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
trash_result["warning"] = (
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
)
trash_result["deleted_from_kb"] = deleted_from_kb
if deleted_from_kb:
trash_result["message"] = (
f"{trash_result.get('message', '')} (also removed from knowledge base)"
)
return trash_result
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error deleting Google Drive file: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while trashing the file. Please try again.",
}
return delete_google_drive_file


@ -172,12 +172,52 @@ def _normalize_connectors(
# =============================================================================
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
# Fraction of the model's context window (in characters) that a single tool
# result is allowed to occupy. The remainder is reserved for system prompt,
# conversation history, and model output. With ~4 chars/token this gives a
# tool result ≈ 25 % of the context budget in tokens.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4
# Hard-floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
_MAX_CHUNK_CHARS = 8_000
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
"""Derive a character budget from the model's context window.
The token limit originates from ``litellm.get_model_info``; it is resolved by
``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
chain as ``max_input_tokens``. Falls back to the minimum budget when
the value is unavailable.
"""
if max_input_tokens is None or max_input_tokens <= 0:
return _MIN_TOOL_OUTPUT_CHARS # conservative fallback
budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
def format_documents_for_context(
documents: list[dict[str, Any]],
*,
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
max_chunk_chars: int = _MAX_CHUNK_CHARS,
) -> str:
"""
Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
a single oversized chunk cannot monopolize the output.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content beyond the cap is dropped and marked as truncated).
Returns:
Formatted string with document contents and metadata
@ -278,37 +318,57 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
"BAIDU_SEARCH_API",
}
# Render XML expected by citation instructions
# Render XML expected by citation instructions, respecting the char budget.
parts: list[str] = []
for g in grouped.values():
total_chars = 0
total_docs = len(grouped)
for doc_idx, g in enumerate(grouped.values()):
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
is_live_search = g["document_type"] in live_search_connectors
parts.append("<document>")
parts.append("<document_metadata>")
parts.append(f" <document_id>{g['document_id']}</document_id>")
parts.append(f" <document_type>{g['document_type']}</document_type>")
parts.append(f" <title><![CDATA[{g['title']}]]></title>")
parts.append(f" <url><![CDATA[{g['url']}]]></url>")
parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
parts.append("</document_metadata>")
parts.append("")
parts.append("<document_content>")
doc_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{g['document_id']}</document_id>",
f" <document_type>{g['document_type']}</document_type>",
f" <title><![CDATA[{g['title']}]]></title>",
f" <url><![CDATA[{g['url']}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
"<document_content>",
]
for ch in g["chunks"]:
ch_content = ch["content"]
# For live search connectors, use the document URL as the chunk id
# so the LLM outputs [citation:https://...] which the frontend
# renders as a clickable link.
if max_chunk_chars and len(ch_content) > max_chunk_chars:
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
if ch_id is None:
parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
else:
parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
doc_lines.append(
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
)
parts.append("</document_content>")
parts.append("</document>")
parts.append("")
doc_lines.extend(["</document_content>", "</document>", ""])
doc_xml = "\n".join(doc_lines)
doc_len = len(doc_xml)
# Always include at least the first document; afterwards enforce budget.
if doc_idx > 0 and total_chars + doc_len > max_chars:
remaining = total_docs - doc_idx
parts.append(
f"<!-- Output truncated: {remaining} more document(s) omitted "
f"(budget {max_chars} chars). Refine your query or reduce top_k "
f"to retrieve different results. -->"
)
break
parts.append(doc_xml)
total_chars += doc_len
return "\n".join(parts).strip()
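The budget loop above boils down to a greedy packing rule: keep documents in relevance order, always keep the first, and stop with a truncation note once the next document would exceed the budget. A minimal sketch of that rule on plain strings follows; `pack_documents` is a hypothetical name and the note text is shortened.

```python
def pack_documents(docs: list[str], max_chars: int) -> str:
    """Join rendered documents until the character budget is reached.

    The first document is always kept, even if it alone exceeds the budget,
    so a non-empty input never yields an empty result.
    """
    parts: list[str] = []
    total = 0
    for idx, doc in enumerate(docs):
        if idx > 0 and total + len(doc) > max_chars:
            remaining = len(docs) - idx
            parts.append(f"<!-- Output truncated: {remaining} more document(s) omitted -->")
            break
        parts.append(doc)
        total += len(doc)
    return "\n".join(parts).strip()


print(pack_documents(["a" * 10, "b" * 10, "c" * 10], max_chars=15))
```

Note that the budget counts only document text, not the joining newlines or the truncation note, which is why the docstring calls `max_chars` an approximate budget.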
@ -328,6 +388,7 @@ async def search_knowledge_base_async(
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
max_input_tokens: int | None = None,
) -> str:
"""
Search the user's knowledge base for relevant documents.
@ -345,6 +406,8 @@ async def search_knowledge_base_async(
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
@ -488,7 +551,8 @@ async def search_knowledge_base_async(
deduplicated.append(doc)
return format_documents_for_context(deduplicated)
output_budget = _compute_tool_output_budget(max_input_tokens)
return format_documents_for_context(deduplicated, max_chars=output_budget)
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@ -552,6 +616,7 @@ def create_search_knowledge_base_tool(
connector_service: ConnectorService,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> StructuredTool:
"""
Factory function to create the search_knowledge_base tool with injected dependencies.
@ -564,6 +629,8 @@ def create_search_knowledge_base_tool(
Used to dynamically generate the tool docstring.
available_document_types: Optional list of document types that have data in the search space.
Used to inform the LLM about what data exists.
max_input_tokens: Model context window (tokens) from litellm model info.
Used to dynamically size tool output.
Returns:
A configured StructuredTool instance
@ -634,6 +701,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
max_input_tokens=max_input_tokens,
)
# Create StructuredTool with dynamic description


@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
import time
from typing import Any
from langchain_core.tools import StructuredTool
@ -25,6 +26,9 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
def _create_dynamic_input_model_from_schema(
tool_name: str,
@ -355,6 +359,19 @@ async def _load_http_mcp_tools(
return tools
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None)
else:
_mcp_tools_cache.clear()
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
@ -364,6 +381,9 @@ async def load_mcp_tools(
This discovers tools dynamically from MCP servers using the protocol.
Supports both stdio (local process) and HTTP (remote server) transports.
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Args:
session: Database session
search_space_id: User's search space ID
@ -372,8 +392,20 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances
"""
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
search_space_id,
len(cached_tools),
now - cached_at,
)
return list(cached_tools)
try:
# Fetch all MCP connectors for this search space
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
@ -385,27 +417,22 @@ async def load_mcp_tools(
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Early validation: Extract and validate connector config
config = connector.config or {}
server_config = config.get("server_config", {})
# Validate server_config exists and is a dict
if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue
# Determine transport type
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
# HTTP-based MCP server
connector_tools = await _load_http_mcp_tools(
connector.id, connector.name, server_config
)
else:
# stdio-based MCP server (default)
connector_tools = await _load_stdio_mcp_tools(
connector.id, connector.name, server_config
)
@ -417,6 +444,7 @@ async def load_mcp_tools(
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
)
_mcp_tools_cache[search_space_id] = (now, list(tools))
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools


@ -47,6 +47,10 @@ from app.db import ChatVisibility
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .knowledge_base import create_search_knowledge_base_tool
from .linear import (
create_create_linear_issue_tool,
@ -114,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
# Optional: dynamically discovered connectors/document types
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
max_input_tokens=deps.get("max_input_tokens"),
),
requires=["search_space_id", "db_session", "connector_service"],
# Note: available_connectors and available_document_types are optional
@ -292,6 +297,29 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# GOOGLE DRIVE TOOLS - create files, delete files
# =========================================================================
ToolDefinition(
name="create_google_drive_file",
description="Create a new Google Doc or Google Sheet in Google Drive",
factory=lambda deps: create_create_google_drive_file_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_google_drive_file",
description="Move an indexed Google Drive file to trash",
factory=lambda deps: create_delete_google_drive_file_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
]
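Each registry entry pairs a factory with the dependency keys it needs, so a builder can check `requires` before calling the factory. How `build_tools` actually enforces this is not shown in the diff; the sketch below is one plausible reading, with `ToolSpec` and `build` as hypothetical, simplified stand-ins for `ToolDefinition` and the real builder.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToolSpec:  # hypothetical stand-in for ToolDefinition
    name: str
    factory: Callable[[dict[str, Any]], Any]
    requires: list[str] = field(default_factory=list)


def build(specs: list[ToolSpec], deps: dict[str, Any]) -> list[Any]:
    """Instantiate each tool, failing early if a required dependency is absent."""
    built: list[Any] = []
    for spec in specs:
        missing = [key for key in spec.requires if key not in deps]
        if missing:
            raise ValueError(f"Tool '{spec.name}' is missing dependencies: {missing}")
        built.append(spec.factory(deps))
    return built


specs = [ToolSpec("echo", lambda d: ("echo", d["user_id"]), ["user_id"])]
print(build(specs, {"user_id": "u1"}))
```

Declaring requirements next to the factory keeps a missing dependency from surfacing later as a `KeyError` inside the lambda.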
@ -417,8 +445,18 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
# Build standard tools
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
_t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
_perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Load MCP tools if requested and dependencies are available
if (
@ -427,10 +465,16 @@ async def build_tools_async(
and "search_space_id" in dependencies
):
try:
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
_perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(mcp_tools),
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",


@ -14,8 +14,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
from app.utils.document_converters import embed_text
def format_surfsense_docs_results(results: list[tuple]) -> str:
@ -100,7 +100,7 @@ async def search_surfsense_docs_async(
Formatted string with relevant documentation content
"""
# Get embedding for the query
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
# Vector similarity search on chunks, joining with documents
stmt = (


@ -8,8 +8,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, SharedMemory, User
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
@ -64,7 +64,7 @@ async def save_shared_memory(
count = await get_shared_memory_count(db_session, search_space_id)
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
await delete_oldest_shared_memory(db_session, search_space_id)
embedding = config.embedding_model_instance.embed(content)
embedding = embed_text(content)
row = SharedMemory(
search_space_id=search_space_id,
created_by_id=_to_uuid(created_by_id),
@ -108,7 +108,7 @@ async def recall_shared_memory(
if category and category in valid_categories:
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
if query:
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
stmt = stmt.order_by(
SharedMemory.embedding.op("<=>")(query_embedding)
).limit(top_k)

View file

@ -17,8 +17,8 @@ from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, UserMemory
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
@ -178,7 +178,7 @@ def create_save_memory_tool(
await delete_oldest_memory(db_session, user_id, search_space_id)
# Generate embedding for the memory
embedding = config.embedding_model_instance.embed(content)
embedding = embed_text(content)
# Create new memory using ORM
# The pgvector Vector column type handles embedding conversion automatically
@ -268,7 +268,7 @@ def create_recall_memory_tool(
if query:
# Semantic search using embeddings
query_embedding = config.embedding_model_instance.embed(query)
query_embedding = embed_text(query)
# Build query with vector similarity
stmt = (

View file

@ -175,8 +175,39 @@ def rate_limit_password_reset(request: Request):
)
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
"""Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
This helps pinpoint synchronous code that freezes the entire FastAPI server.
Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
"""
import os
if not os.environ.get("PERF_DEBUG"):
return
_slow_log = logging.getLogger("surfsense.perf.slow")
_slow_log.setLevel(logging.WARNING)
if not _slow_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
_slow_log.addHandler(_h)
_slow_log.propagate = False
loop = asyncio.get_running_loop()
loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined]
loop.set_debug(True)
_slow_log.warning(
"Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
"Set PERF_DEBUG='' to disable.",
threshold_sec,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
_enable_slow_callback_logging(threshold_sec=0.5)
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
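
The slow-callback detector added above relies on asyncio's debug mode. A minimal standalone sketch of that mechanism follows, assuming CPython's behaviour of reporting slow callbacks through the standard `asyncio` logger. The names `_ListHandler`, `_demo`, and `run_demo` are hypothetical and not part of this PR.

```python
import asyncio
import logging
import time


class _ListHandler(logging.Handler):
    """Hypothetical helper: collects log messages so they can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


async def _demo(threshold_sec: float, block_sec: float) -> None:
    loop = asyncio.get_running_loop()
    loop.slow_callback_duration = threshold_sec
    loop.set_debug(True)
    # Yield once: debug timing applies to steps that start after it is enabled.
    await asyncio.sleep(0)
    time.sleep(block_sec)  # deliberately blocks the event loop
    await asyncio.sleep(0)


def run_demo(threshold_sec: float = 0.05, block_sec: float = 0.2) -> list[str]:
    """Run the blocking coroutine and return the slow-callback warnings seen."""
    handler = _ListHandler()
    asyncio_log = logging.getLogger("asyncio")
    asyncio_log.addHandler(handler)
    try:
        asyncio.run(_demo(threshold_sec, block_sec))
    finally:
        asyncio_log.removeHandler(handler)
    return [m for m in handler.messages if "took" in m]


if __name__ == "__main__":
    for message in run_demo():
        print(message)
```

Note that in this sketch the warnings arrive on the `asyncio` logger, so a handler attached only to another logger would not see them unless propagation or handlers are configured accordingly.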

View file

@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -383,6 +383,7 @@ async def _process_gmail_messages_phase2(
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
"""
@ -415,7 +416,7 @@ async def _process_gmail_messages_phase2(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
@ -427,10 +428,8 @@ async def _process_gmail_messages_phase2(
item["markdown_content"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
@ -646,6 +645,7 @@ async def index_composio_gmail(
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)

View file

@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.connectors.composio_connector import ComposioConnector
from app.db import Document, DocumentStatus, DocumentType
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -440,7 +440,7 @@ async def index_composio_google_calendar(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"summary": item["summary"],
@ -456,12 +456,10 @@ async def index_composio_google_calendar(
document_metadata_for_summary,
)
else:
summary_content = f"Calendar: {item['summary']}\n\nStart: {item['start_time']}\nEnd: {item['end_time']}"
if item["location"]:
summary_content += f"\nLocation: {item['location']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
summary_content = (
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])

View file

@ -31,6 +31,7 @@ from app.tasks.connector_indexers.base import (
)
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -714,6 +715,7 @@ async def index_composio_google_drive(
max_items=max_items,
task_logger=task_logger,
log_entry=log_entry,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)
else:
@ -747,6 +749,7 @@ async def index_composio_google_drive(
max_items=max_items,
task_logger=task_logger,
log_entry=log_entry,
enable_summary=getattr(connector, "enable_summary", False),
on_heartbeat_callback=on_heartbeat_callback,
)
@ -829,6 +832,7 @@ async def _index_composio_drive_delta_sync(
max_items: int,
task_logger: TaskLoggingService,
log_entry,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index Google Drive files using delta sync with real-time document status updates.
@ -1079,7 +1083,7 @@ async def _index_composio_drive_delta_sync(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"file_id": item["file_id"],
"file_name": item["file_name"],
@ -1090,10 +1094,8 @@ async def _index_composio_drive_delta_sync(
markdown_content, user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(markdown_content)
@ -1155,6 +1157,7 @@ async def _index_composio_drive_full_scan(
max_items: int,
task_logger: TaskLoggingService,
log_entry,
enable_summary: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index Google Drive files using full scan with real-time document status updates."""
@ -1488,7 +1491,7 @@ async def _index_composio_drive_full_scan(
session, user_id, search_space_id
)
if user_llm:
if user_llm and enable_summary:
document_metadata_for_summary = {
"file_id": item["file_id"],
"file_name": item["file_name"],
@ -1499,10 +1502,8 @@ async def _index_composio_drive_full_scan(
markdown_content, user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(markdown_content)

View file

@ -1,12 +1,15 @@
"""Google Drive API client."""
import io
from typing import Any
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload
from sqlalchemy.ext.asyncio import AsyncSession
from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
class GoogleDriveClient:
@ -179,3 +182,65 @@ class GoogleDriveClient:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
async def create_file(
self,
name: str,
mime_type: str,
parent_folder_id: str | None = None,
content: str | None = None,
) -> dict[str, Any]:
service = await self.get_service()
body: dict[str, Any] = {"name": name, "mimeType": mime_type}
if parent_folder_id:
body["parents"] = [parent_folder_id]
media: MediaIoBaseUpload | None = None
if content:
if mime_type == GOOGLE_DOC:
import markdown as md_lib
html = md_lib.markdown(content)
media = MediaIoBaseUpload(
io.BytesIO(html.encode("utf-8")),
mimetype="text/html",
resumable=False,
)
elif mime_type == GOOGLE_SHEET:
media = MediaIoBaseUpload(
io.BytesIO(content.encode("utf-8")),
mimetype="text/csv",
resumable=False,
)
if media:
return (
service.files()
.create(
body=body,
media_body=media,
fields="id,name,mimeType,webViewLink",
supportsAllDrives=True,
)
.execute()
)
return (
service.files()
.create(
body=body,
fields="id,name,mimeType,webViewLink",
supportsAllDrives=True,
)
.execute()
)
async def trash_file(self, file_id: str) -> bool:
service = await self.get_service()
service.files().update(
fileId=file_id,
body={"trashed": True},
supportsAllDrives=True,
).execute()
return True

View file

@ -1,6 +1,6 @@
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from enum import Enum
from enum import StrEnum
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
@ -31,7 +31,7 @@ if config.AUTH_TYPE == "GOOGLE":
DATABASE_URL = config.DATABASE_URL
class DocumentType(str, Enum):
class DocumentType(StrEnum):
EXTENSION = "EXTENSION"
CRAWLED_URL = "CRAWLED_URL"
FILE = "FILE"
@ -60,7 +60,7 @@ class DocumentType(str, Enum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
class SearchSourceConnectorType(str, Enum):
class SearchSourceConnectorType(StrEnum):
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
TAVILY_API = "TAVILY_API"
SEARXNG_API = "SEARXNG_API"
@ -93,7 +93,7 @@ class SearchSourceConnectorType(str, Enum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
class PodcastStatus(str, Enum):
class PodcastStatus(StrEnum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"
@ -177,7 +177,7 @@ class DocumentStatus:
return None
class LiteLLMProvider(str, Enum):
class LiteLLMProvider(StrEnum):
"""
Enum for LLM providers supported by LiteLLM.
"""
@ -215,7 +215,7 @@ class LiteLLMProvider(str, Enum):
CUSTOM = "CUSTOM"
class ImageGenProvider(str, Enum):
class ImageGenProvider(StrEnum):
"""
Enum for image generation providers supported by LiteLLM.
This is a subset of LLM providers: only those that support image generation.
@ -233,7 +233,7 @@ class ImageGenProvider(str, Enum):
NSCALE = "NSCALE"
class LogLevel(str, Enum):
class LogLevel(StrEnum):
DEBUG = "DEBUG"
INFO = "INFO"
WARNING = "WARNING"
@ -241,13 +241,13 @@ class LogLevel(str, Enum):
CRITICAL = "CRITICAL"
class LogStatus(str, Enum):
class LogStatus(StrEnum):
IN_PROGRESS = "IN_PROGRESS"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
class IncentiveTaskType(str, Enum):
class IncentiveTaskType(StrEnum):
"""
Enum for incentive task types that users can complete to earn free pages.
Each task can only be completed once per user.
@ -298,7 +298,7 @@ INCENTIVE_TASKS_CONFIG = {
}
class Permission(str, Enum):
class Permission(StrEnum):
"""
Granular permissions for search space resources.
Use '*' (FULL_ACCESS) to grant all permissions.
@ -471,7 +471,7 @@ class BaseModel(Base):
id = Column(Integer, primary_key=True, index=True)
class NewChatMessageRole(str, Enum):
class NewChatMessageRole(StrEnum):
"""Role enum for new chat messages."""
USER = "user"
@ -479,7 +479,7 @@ class NewChatMessageRole(str, Enum):
SYSTEM = "system"
class ChatVisibility(str, Enum):
class ChatVisibility(StrEnum):
"""
Visibility/sharing level for chat threads.
@ -788,7 +788,7 @@ class ChatSessionState(BaseModel):
ai_responding_to_user = relationship("User")
class MemoryCategory(str, Enum):
class MemoryCategory(StrEnum):
"""Categories for user memories."""
# Using lowercase keys to match PostgreSQL enum values
@ -1317,6 +1317,12 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
config = Column(JSON, nullable=False)
# Summary generation (LLM-based) - disabled by default to save resources.
# When enabled, improves hybrid search quality at the cost of LLM calls.
enable_summary = Column(
Boolean, nullable=False, default=False, server_default="false"
)
# Periodic indexing fields
periodic_indexing_enabled = Column(Boolean, nullable=False, default=False)
indexing_frequency_minutes = Column(Integer, nullable=True)

View file

@ -0,0 +1,47 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
async def index_uploaded_file(
markdown_content: str,
filename: str,
etl_service: str,
search_space_id: int,
user_id: str,
session: AsyncSession,
llm,
should_summarize: bool = False,
) -> None:
connector_doc = ConnectorDocument(
title=filename,
source_markdown=markdown_content,
unique_id=filename,
document_type=DocumentType.FILE,
search_space_id=search_space_id,
created_by_id=user_id,
connector_id=None,
should_summarize=should_summarize,
should_use_code_chunker=False,
fallback_summary=markdown_content[:4000],
metadata={
"FILE_NAME": filename,
"ETL_SERVICE": etl_service,
},
)
service = IndexingPipelineService(session)
documents = await service.prepare_for_indexing([connector_doc])
if not documents:
raise RuntimeError("prepare_for_indexing returned no documents")
indexed = await service.index(documents[0], connector_doc, llm)
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
raise RuntimeError(indexed.status.get("reason", "Indexing failed"))
indexed.content_needs_reindexing = False
await session.commit()

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel, Field, field_validator
from app.db import DocumentType
class ConnectorDocument(BaseModel):
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
title: str
source_markdown: str
unique_id: str
document_type: DocumentType
search_space_id: int = Field(gt=0)
should_summarize: bool = True
should_use_code_chunker: bool = False
fallback_summary: str | None = None
metadata: dict = {}
connector_id: int | None = None
created_by_id: str
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
@classmethod
def not_empty(cls, v: str, info) -> str:
if not v.strip():
raise ValueError(f"{info.field_name} must not be empty or whitespace")
return v

View file

@ -0,0 +1,9 @@
from app.config import config
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
"""Chunk a text string using the configured chunker and return the chunk texts."""
chunker = (
config.code_chunker_instance if use_code_chunker else config.chunker_instance
)
return [c.text for c in chunker.chunk(text)]

View file

@ -0,0 +1,3 @@
from app.utils.document_converters import embed_text
__all__ = ["embed_text"]

View file

@ -0,0 +1,15 @@
import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_content_hash(doc: ConnectorDocument) -> str:
"""Return a SHA-256 hash of the document's content scoped to its search space."""
combined = f"{doc.search_space_id}:{doc.source_markdown}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
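
To illustrate how the two hashes above behave, here is a minimal sketch using a hypothetical `DocStub` dataclass in place of `ConnectorDocument` (a plain string stands in for `document_type.value`). It mirrors the string formats shown in the diff and is not the pipeline's own code.

```python
import hashlib
from dataclasses import dataclass


@dataclass
class DocStub:
    """Hypothetical stand-in carrying only the fields that get hashed."""

    document_type: str
    unique_id: str
    search_space_id: int
    source_markdown: str


def identity_hash(doc: DocStub) -> str:
    # Same format as compute_unique_identifier_hash: type, id, search space.
    combined = f"{doc.document_type}:{doc.unique_id}:{doc.search_space_id}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


def content_hash(doc: DocStub) -> str:
    # Same format as compute_content_hash: content scoped to a search space.
    combined = f"{doc.search_space_id}:{doc.source_markdown}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


original = DocStub("FILE", "report.md", 1, "# v1")
edited = DocStub("FILE", "report.md", 1, "# v2")
other_space = DocStub("FILE", "report.md", 2, "# v1")

if __name__ == "__main__":
    print(identity_hash(original) == identity_hash(edited))   # identity is stable
    print(content_hash(original) == content_hash(edited))     # content changed
```

An edit to the document keeps the identity hash and changes the content hash, which is the signal `prepare_for_indexing` uses to re-queue an existing row. The same content in a different search space gets a different content hash.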

View file

@ -0,0 +1,39 @@
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
async def rollback_and_persist_failure(
session: AsyncSession, document: Document, message: str
) -> None:
"""Roll back the current transaction and best-effort persist a failed status.
Called exclusively from except blocks; it must never raise, or the new exception
would chain with the original and mask it entirely.
"""
try:
await session.rollback()
except Exception:
return # Session is completely dead; nothing further we can do.
try:
await session.refresh(document)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
pass # Best-effort; document will be retried on the next sync.
def attach_chunks_to_document(document: Document, chunks: list) -> None:
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
set_committed_value(document, "chunks", chunks)
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)

View file

@ -0,0 +1,30 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window
async def summarize_document(
source_markdown: str, llm, metadata: dict | None = None
) -> str:
"""Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
model_name = getattr(llm, "model", "gpt-3.5-turbo")
optimized_content = optimize_content_for_context_window(
source_markdown, metadata, model_name
)
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
content_with_metadata = (
f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
)
summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
summary_content = summary_result.content
if metadata:
metadata_parts = ["# DOCUMENT METADATA"]
for key, value in metadata.items():
if value:
metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
metadata_section = "\n".join(metadata_parts)
return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
return summary_content

View file

@ -0,0 +1,146 @@
from litellm.exceptions import (
APIConnectionError,
APIResponseValidationError,
AuthenticationError,
BadGatewayError,
BadRequestError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
ServiceUnavailableError,
Timeout,
UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError as IntegrityError
# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
RateLimitError,
Timeout,
ServiceUnavailableError,
BadGatewayError,
InternalServerError,
APIConnectionError,
)
PERMANENT_LLM_ERRORS = (
AuthenticationError,
PermissionDeniedError,
NotFoundError,
BadRequestError,
UnprocessableEntityError,
APIResponseValidationError,
)
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
EMBEDDING_ERRORS = (
RuntimeError, # local device failure or API backend normalization
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
)
class PipelineMessages:
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
LLM_UNPROCESSABLE = (
"Document exceeds the LLM context window even after optimization."
)
LLM_RESPONSE = "LLM returned an invalid response."
EMBEDDING_FAILED = (
"Embedding failed. Check your embedding model configuration or service."
)
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
EMBEDDING_MEMORY = "Not enough memory to embed this document."
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
def safe_exception_message(exc: Exception) -> str:
try:
return str(exc)
except Exception:
return "Something went wrong during indexing. Error details could not be retrieved."
def llm_retryable_message(exc: Exception) -> str:
try:
if isinstance(exc, RateLimitError):
return PipelineMessages.RATE_LIMIT
if isinstance(exc, Timeout):
return PipelineMessages.LLM_TIMEOUT
if isinstance(exc, ServiceUnavailableError):
return PipelineMessages.LLM_UNAVAILABLE
if isinstance(exc, BadGatewayError):
return PipelineMessages.LLM_BAD_GATEWAY
if isinstance(exc, InternalServerError):
return PipelineMessages.LLM_SERVER_ERROR
if isinstance(exc, APIConnectionError):
return PipelineMessages.LLM_CONNECTION
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def llm_permanent_message(exc: Exception) -> str:
try:
if isinstance(exc, AuthenticationError):
return PipelineMessages.LLM_AUTH
if isinstance(exc, PermissionDeniedError):
return PipelineMessages.LLM_PERMISSION
if isinstance(exc, NotFoundError):
return PipelineMessages.LLM_NOT_FOUND
if isinstance(exc, BadRequestError):
return PipelineMessages.LLM_BAD_REQUEST
if isinstance(exc, UnprocessableEntityError):
return PipelineMessages.LLM_UNPROCESSABLE
if isinstance(exc, APIResponseValidationError):
return PipelineMessages.LLM_RESPONSE
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def embedding_message(exc: Exception) -> str:
try:
if isinstance(exc, RuntimeError):
return PipelineMessages.EMBEDDING_FAILED
if isinstance(exc, OSError):
return PipelineMessages.EMBEDDING_MODEL
if isinstance(exc, MemoryError):
return PipelineMessages.EMBEDDING_MEMORY
return safe_exception_message(exc)
except Exception:
return "Something went wrong when generating the embedding."

View file

@ -0,0 +1,237 @@
import contextlib
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_text
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
EMBEDDING_ERRORS,
PERMANENT_LLM_ERRORS,
RETRYABLE_LLM_ERRORS,
PipelineMessages,
embedding_message,
llm_permanent_message,
llm_retryable_message,
safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
PipelineLogContext,
log_batch_aborted,
log_chunking_overflow,
log_doc_skipped_unknown,
log_document_queued,
log_document_requeued,
log_document_updated,
log_embedding_error,
log_index_started,
log_index_success,
log_permanent_llm_error,
log_race_condition,
log_retryable_llm_error,
log_unexpected_error,
)
class IndexingPipelineService:
"""Single pipeline for indexing connector documents. All connectors use this service."""
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument]
) -> list[Document]:
"""
Persist new documents and detect changes, returning only those that need indexing.
"""
documents = []
seen_hashes: set[str] = set()
batch_ctx = PipelineLogContext(
connector_id=connector_docs[0].connector_id if connector_docs else 0,
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
unique_id="batch",
)
for connector_doc in connector_docs:
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
)
try:
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
content_hash = compute_content_hash(connector_doc)
if unique_identifier_hash in seen_hashes:
continue
seen_hashes.add(unique_identifier_hash)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == unique_identifier_hash
)
)
existing = result.scalars().first()
if existing is not None:
if existing.content_hash == content_hash:
if existing.title != connector_doc.title:
existing.title = connector_doc.title
existing.updated_at = datetime.now(UTC)
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.pending()
existing.updated_at = datetime.now(UTC)
documents.append(existing)
log_document_requeued(ctx)
continue
existing.title = connector_doc.title
existing.content_hash = content_hash
existing.source_markdown = connector_doc.source_markdown
existing.document_metadata = connector_doc.metadata
existing.updated_at = datetime.now(UTC)
existing.status = DocumentStatus.pending()
documents.append(existing)
log_document_updated(ctx)
continue
duplicate = await self.session.execute(
select(Document).filter(Document.content_hash == content_hash)
)
if duplicate.scalars().first() is not None:
continue
document = Document(
title=connector_doc.title,
document_type=connector_doc.document_type,
content="Pending...",
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=connector_doc.source_markdown,
document_metadata=connector_doc.metadata,
search_space_id=connector_doc.search_space_id,
connector_id=connector_doc.connector_id,
created_by_id=connector_doc.created_by_id,
updated_at=datetime.now(UTC),
status=DocumentStatus.pending(),
)
self.session.add(document)
documents.append(document)
log_document_queued(ctx)
except Exception as e:
log_doc_skipped_unknown(ctx, e)
try:
await self.session.commit()
return documents
except IntegrityError:
# A concurrent worker committed a document with the same content_hash
# or unique_identifier_hash between our check and our INSERT.
# The document already exists — roll back and let the next sync run handle it.
log_race_condition(batch_ctx)
await self.session.rollback()
return []
except Exception as e:
log_batch_aborted(batch_ctx, e)
await self.session.rollback()
return []
async def index(
self, document: Document, connector_doc: ConnectorDocument, llm
) -> Document:
"""
Run summarization, embedding, and chunking for a document and persist the results.
"""
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
doc_id=document.id,
)
try:
log_index_started(ctx)
document.status = DocumentStatus.processing()
await self.session.commit()
if connector_doc.should_summarize and llm is not None:
content = await summarize_document(
connector_doc.source_markdown, llm, connector_doc.metadata
)
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
else:
content = connector_doc.source_markdown
embedding = embed_text(content)
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
chunks = [
Chunk(content=text, embedding=embed_text(text))
for text in chunk_text(
connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker,
)
]
document.content = content
document.embedding = embedding
attach_chunks_to_document(document, chunks)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
log_index_success(ctx, chunk_count=len(chunks))
except RETRYABLE_LLM_ERRORS as e:
log_retryable_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_retryable_message(e)
)
except PERMANENT_LLM_ERRORS as e:
log_permanent_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_permanent_message(e)
)
except RecursionError as e:
log_chunking_overflow(ctx, e)
await rollback_and_persist_failure(
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
)
except EMBEDDING_ERRORS as e:
log_embedding_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, embedding_message(e)
)
except Exception as e:
log_unexpected_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, safe_exception_message(e)
)
with contextlib.suppress(Exception):
await self.session.refresh(document)
return document
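
The order of the except clauses in `index()` matters because `RecursionError` is a subclass of `RuntimeError`, and `RuntimeError` is part of `EMBEDDING_ERRORS`. A minimal sketch of that ordering follows; the `classify` function and its return labels are hypothetical.

```python
def classify(exc: BaseException) -> str:
    """Hypothetical helper: report which handler an exception would reach."""
    try:
        raise exc
    except RecursionError:
        # Must come first: RecursionError is also a RuntimeError.
        return "chunking_overflow"
    except (RuntimeError, OSError, MemoryError):
        # Same exception types as the EMBEDDING_ERRORS tuple in the diff.
        return "embedding_error"
    except Exception:
        return "unexpected"


if __name__ == "__main__":
    print(classify(RecursionError("too deep")))
    print(classify(RuntimeError("device failure")))
    print(classify(ValueError("other")))
```

If the two clauses were swapped, a chunking overflow would be reported with the embedding failure message instead.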


@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class PipelineLogContext:
connector_id: int | None
search_space_id: int
unique_id: str # always available from ConnectorDocument
doc_id: int | None = None # set once the DB row exists (index phase only)
class LogMessages:
# prepare_for_indexing
DOCUMENT_QUEUED = "New document queued for indexing."
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
# index
INDEX_STARTED = "Document indexing started."
INDEX_SUCCESS = "Document indexed successfully."
LLM_RETRYABLE = (
"Retryable LLM error — document marked failed, will retry on next sync."
)
LLM_PERMANENT = "Permanent LLM error — document marked failed."
EMBEDDING_FAILED = "Embedding error — document marked failed."
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
UNEXPECTED = "Unexpected error — document marked failed."
def _format_context(ctx: PipelineLogContext) -> str:
parts = [
f"connector_id={ctx.connector_id}",
f"search_space_id={ctx.search_space_id}",
f"unique_id={ctx.unique_id}",
]
if ctx.doc_id is not None:
parts.append(f"doc_id={ctx.doc_id}")
return " ".join(parts)
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
try:
parts = [msg, _format_context(ctx)]
for key, val in extra.items():
parts.append(f"{key}={val}")
return " ".join(parts)
except Exception:
return msg
def _safe_log(
level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
) -> None:
# Logging must never raise — a broken log call inside an except block would
# chain with the original exception and mask it entirely.
try:
message = _build_message(msg, ctx, **extra)
level_fn(message, exc_info=exc_info)
except Exception:
pass
# ── prepare_for_indexing ──────────────────────────────────────────────────────
def log_document_queued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
def log_document_updated(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
def log_document_requeued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(
logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
)
def log_race_condition(ctx: PipelineLogContext) -> None:
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
# ── index ─────────────────────────────────────────────────────────────────────
def log_index_started(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)
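
The logger module above wraps every call so that a failure while formatting or emitting a log line can never replace the exception being handled. The sketch below reproduces that idea in a self-contained form; `DemoContext`, `build_message`, `safe_log`, and `Unprintable` are illustrative names chosen for this example, not identifiers from the commit.

```python
import logging
from dataclasses import dataclass
from typing import Optional

demo_logger = logging.getLogger("pipeline_demo")


@dataclass
class DemoContext:
    connector_id: Optional[int]
    search_space_id: int
    unique_id: str
    doc_id: Optional[int] = None


def build_message(msg: str, ctx: DemoContext, **extra) -> str:
    """Append key=value context to msg; fall back to msg alone on any error."""
    try:
        parts = [
            msg,
            f"connector_id={ctx.connector_id}",
            f"search_space_id={ctx.search_space_id}",
            f"unique_id={ctx.unique_id}",
        ]
        if ctx.doc_id is not None:
            parts.append(f"doc_id={ctx.doc_id}")
        parts.extend(f"{key}={val}" for key, val in extra.items())
        return " ".join(parts)
    except Exception:
        return msg


def safe_log(level_fn, msg: str, ctx: DemoContext, exc_info=None, **extra) -> None:
    """Call level_fn, swallowing anything it raises."""
    try:
        level_fn(build_message(msg, ctx, **extra), exc_info=exc_info)
    except Exception:
        pass


class Unprintable:
    """Value whose string conversion fails, to exercise the fallback path."""

    def __str__(self) -> str:
        raise ValueError("cannot render")
```

A call such as `safe_log(demo_logger.info, "Indexed.", ctx, chunk_count=3)` behaves like a normal log call; the difference only shows when formatting or the handler itself fails.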


@ -36,6 +36,7 @@ from .podcasts_routes import router as podcasts_router
from .public_chat_routes import router as public_chat_router
from .rbac_routes import router as rbac_router
from .reports_routes import router as reports_router
from .sandbox_routes import router as sandbox_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router
from .slack_add_connector_route import router as slack_add_connector_router
@ -50,6 +51,7 @@ router.include_router(editor_router)
router.include_router(documents_router)
router.include_router(notes_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
router.include_router(chat_comments_router)
router.include_router(podcasts_router) # Podcast task status and audio
router.include_router(reports_router) # Report CRUD and export (PDF/DOCX)


@ -28,6 +28,7 @@ from app.schemas import (
DocumentWithChunksRead,
PaginatedResponse,
)
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -44,6 +45,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
router = APIRouter()
MAX_FILES_PER_UPLOAD = 10
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
@router.post("/documents")
async def create_documents(
@ -114,8 +119,10 @@ async def create_documents(
async def create_documents_file_upload(
files: list[UploadFile],
search_space_id: int = Form(...),
should_summarize: bool = Form(False),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
):
"""
Upload files as documents with real-time status tracking.
@ -148,12 +155,37 @@ async def create_documents_file_upload(
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > MAX_FILES_PER_UPLOAD:
raise HTTPException(
status_code=413,
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
)
total_size = 0
for file in files:
file_size = file.size or 0
if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
total_size += file_size
if total_size > MAX_TOTAL_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
created_documents: list[Document] = []
files_to_process: list[
tuple[Document, str, str]
] = [] # (document, temp_path, filename)
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
actual_total_size = 0
# ===== PHASE 1: Create pending documents for all files =====
# This makes ALL documents visible in the UI immediately with pending status
@ -169,11 +201,28 @@ async def create_documents_file_upload(
temp_path = temp_file.name
content = await file.read()
file_size = len(content)
if file_size > MAX_FILE_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
actual_total_size += file_size
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
with open(temp_path, "wb") as f:
f.write(content)
file_size = len(content)
# Generate unique identifier for deduplication check
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, file.filename or "unknown", search_space_id
@ -244,19 +293,16 @@ async def create_documents_file_upload(
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch Celery tasks for each file =====
# ===== PHASE 2: Dispatch tasks for each file =====
# Each task will update document status: pending → processing → ready/failed
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
for document, temp_path, filename in files_to_process:
process_file_upload_with_document_task.delay(
await dispatcher.dispatch_file_processing(
document_id=document.id,
temp_path=temp_path,
filename=filename,
search_space_id=search_space_id,
user_id=str(user.id),
should_summarize=should_summarize,
)
return {
@ -373,10 +419,11 @@ async def read_documents(
# Convert database objects to API-friendly format
api_documents = []
for doc in db_documents:
# Get user name (display_name or email fallback)
created_by_name = None
created_by_email = None
if doc.created_by:
created_by_name = doc.created_by.display_name or doc.created_by.email
created_by_name = doc.created_by.display_name
created_by_email = doc.created_by.email
# Parse status from JSONB
status_data = None
@ -400,6 +447,7 @@ async def read_documents(
search_space_id=doc.search_space_id,
created_by_id=doc.created_by_id,
created_by_name=created_by_name,
created_by_email=created_by_email,
status=status_data,
)
)
@ -528,10 +576,11 @@ async def search_documents(
# Convert database objects to API-friendly format
api_documents = []
for doc in db_documents:
# Get user name (display_name or email fallback)
created_by_name = None
created_by_email = None
if doc.created_by:
created_by_name = doc.created_by.display_name or doc.created_by.email
created_by_name = doc.created_by.display_name
created_by_email = doc.created_by.email
# Parse status from JSONB
status_data = None
@ -555,6 +604,7 @@ async def search_documents(
search_space_id=doc.search_space_id,
created_by_id=doc.created_by_id,
created_by_name=created_by_name,
created_by_email=created_by_email,
status=status_data,
)
)
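
The upload route in this file checks sizes twice: once against the declared `file.size` (which may be missing, hence `file.size or 0`) and again against the bytes actually read. The limit arithmetic itself can be sketched as a pure function. `UploadLimitError` and `check_upload_sizes` are hypothetical names for illustration; the route raises `HTTPException(status_code=413)` instead.

```python
from typing import Sequence

MAX_FILES_PER_UPLOAD = 10
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024  # 50 MB per file
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024  # 200 MB total


class UploadLimitError(ValueError):
    """Hypothetical stand-in for the HTTP 413 raised by the route."""


def check_upload_sizes(sizes: Sequence[int]) -> int:
    """Validate per-file and total limits; return the total size in bytes."""
    if len(sizes) > MAX_FILES_PER_UPLOAD:
        raise UploadLimitError("too many files")
    total = 0
    for size in sizes:
        if size > MAX_FILE_SIZE_BYTES:
            raise UploadLimitError("file exceeds the per-file limit")
        total += size
    if total > MAX_TOTAL_SIZE_BYTES:
        raise UploadLimitError("upload exceeds the total limit")
    return total
```

Both comparisons are strict, so a file of exactly 50 MB and a batch of exactly 200 MB are accepted.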


@ -76,9 +76,9 @@ def get_token_encryption() -> TokenEncryption:
# Google Drive OAuth scopes
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive
"https://www.googleapis.com/auth/userinfo.email", # User email
"https://www.googleapis.com/auth/userinfo.profile", # User profile
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"openid",
]
@ -151,6 +151,75 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
) from e
@router.get("/auth/google/drive/connector/reauth")
async def reauth_drive(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
"""
Initiate Google Drive re-authentication to upgrade OAuth scopes.
Query params:
space_id: Search space ID the connector belongs to
connector_id: ID of the existing connector to re-authenticate
Returns:
JSON with auth_url to redirect user to Google authorization
"""
try:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(
status_code=404,
detail="Google Drive connector not found or access denied",
)
if not config.SECRET_KEY:
raise HTTPException(
status_code=500, detail="SECRET_KEY not configured for OAuth security."
)
flow = get_google_flow()
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
auth_url, _ = flow.authorization_url(
access_type="offline",
prompt="consent",
include_granted_scopes="true",
state=state_encoded,
)
logger.info(
f"Initiating Google Drive re-auth for user {user.id}, connector {connector_id}"
)
return {"auth_url": auth_url}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate Google Drive re-auth: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to initiate Google re-auth: {e!s}"
) from e
@router.get("/auth/google/drive/connector/callback")
async def drive_callback(
request: Request,
@ -214,6 +283,8 @@ async def drive_callback(
user_id = UUID(data["user_id"])
space_id = data["space_id"]
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
logger.info(
f"Processing Google Drive callback for user {user_id}, space {space_id}"
@ -253,7 +324,45 @@ async def drive_callback(
# Mark that credentials are encrypted for backward compatibility
creds_dict["_token_encrypted"] = True
# Check for duplicate connector (same account already connected)
if reauth_connector_id:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == reauth_connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
db_connector = result.scalars().first()
if not db_connector:
raise HTTPException(
status_code=404,
detail="Connector not found or access denied during re-auth",
)
existing_start_page_token = db_connector.config.get("start_page_token")
db_connector.config = {
**creds_dict,
"start_page_token": existing_start_page_token,
}
from sqlalchemy.orm.attributes import flag_modified
flag_modified(db_connector, "config")
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Re-authenticated Google Drive connector {db_connector.id} for user {user_id}"
)
if reauth_return_url and reauth_return_url.startswith("/"):
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
)
return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
)
is_duplicate = await check_duplicate_connector(
session,
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,


@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- POST /threads/{thread_id}/messages - Append message
"""
import asyncio
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
@ -52,9 +54,50 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
router = APIRouter()
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
delete_local_sandbox_files,
delete_sandbox,
is_sandbox_enabled,
)
if not is_sandbox_enabled():
return
async def _bg() -> None:
try:
await delete_sandbox(thread_id)
except Exception:
_logger.warning(
"Background sandbox delete failed for thread %s",
thread_id,
exc_info=True,
)
try:
delete_local_sandbox_files(thread_id)
except Exception:
_logger.warning(
"Local sandbox file cleanup failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_bg())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def check_thread_access(
session: AsyncSession,
thread: NewChatThread,
@ -648,6 +691,9 @@ async def delete_thread(
await session.delete(db_thread)
await session.commit()
_try_delete_sandbox(thread_id)
return {"message": "Thread deleted successfully"}
except HTTPException:


@ -17,7 +17,7 @@ import logging
import os
import re
import tempfile
from enum import Enum
from enum import StrEnum
import pypandoc
import typst
@ -46,7 +46,7 @@ router = APIRouter()
MAX_REPORT_LIST_LIMIT = 500
class ExportFormat(str, Enum):
class ExportFormat(StrEnum):
PDF = "pdf"
DOCX = "docx"


@ -0,0 +1,105 @@
"""Routes for downloading files from Daytona sandbox environments."""
from __future__ import annotations
import asyncio
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import NewChatThread, Permission, User, get_async_session
from app.users import current_active_user
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
router = APIRouter()
MIME_TYPES: dict[str, str] = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
".svg": "image/svg+xml",
".pdf": "application/pdf",
".csv": "text/csv",
".json": "application/json",
".txt": "text/plain",
".html": "text/html",
".md": "text/markdown",
".py": "text/x-python",
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
".zip": "application/zip",
}
def _guess_media_type(filename: str) -> str:
ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
return MIME_TYPES.get(ext, "application/octet-stream")
@router.get("/threads/{thread_id}/sandbox/download")
async def download_sandbox_file(
thread_id: int,
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Download a file from the Daytona sandbox associated with a chat thread."""
from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled
if not is_sandbox_enabled():
raise HTTPException(status_code=404, detail="Sandbox is not enabled")
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
await check_permission(
session,
user,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to access files in this thread",
)
from app.agents.new_chat.sandbox import get_local_sandbox_file
# Prefer locally-persisted copy (sandbox may already be deleted)
local_content = get_local_sandbox_file(thread_id, path)
if local_content is not None:
filename = path.rsplit("/", 1)[-1] if "/" in path else path
media_type = _guess_media_type(filename)
return Response(
content=local_content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
# Fall back to live sandbox download
try:
sandbox = await get_or_create_sandbox(thread_id)
raw_sandbox = sandbox._sandbox
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
except Exception as exc:
logger.warning("Sandbox file download failed for %s: %s", path, exc)
raise HTTPException(
status_code=404, detail=f"Could not download file: {exc}"
) from exc
filename = path.rsplit("/", 1)[-1] if "/" in path else path
media_type = _guess_media_type(filename)
return Response(
content=content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
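
The download route derives both the response filename and its media type from the requested path using plain string operations. The sketch below isolates those two steps with a shortened extension table; `guess_media_type` and `basename` are illustrative names, and the table is a subset of the one in the diff.

```python
# Subset of the extension table from the route module (assumption: trimmed
# here for brevity).
MIME_TYPES = {
    ".png": "image/png",
    ".csv": "text/csv",
    ".pdf": "application/pdf",
}


def guess_media_type(filename: str) -> str:
    """Map the last extension (case-insensitive) to a media type."""
    ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
    return MIME_TYPES.get(ext, "application/octet-stream")


def basename(path: str) -> str:
    """Return the final path component, or the path itself if it has no slash."""
    return path.rsplit("/", 1)[-1] if "/" in path else path
```

Only the final extension is considered, so a name like `archive.tar.gz` falls through to `application/octet-stream` unless `.gz` is added to the table.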


@ -2735,7 +2735,10 @@ async def create_mcp_connector(
f"for user {user.id} in search space {search_space_id}"
)
# Convert to read schema
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2910,6 +2913,10 @@ async def update_mcp_connector(
logger.info(f"Updated MCP connector {connector_id}")
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(connector.search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
"You don't have permission to delete this connector",
)
search_space_id = connector.search_space_id
await session.delete(connector)
await session.commit()
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
logger.info(f"Deleted MCP connector {connector_id}")
except HTTPException:


@ -60,9 +60,8 @@ class DocumentRead(BaseModel):
updated_at: datetime | None
search_space_id: int
created_by_id: UUID | None = None # User who created/uploaded this document
created_by_name: str | None = (
None # Display name or email of the user who created this document
)
created_by_name: str | None = None
created_by_email: str | None = None
status: DocumentStatusSchema | None = (
None # Processing status (ready, processing, failed)
)


@ -1,13 +1,13 @@
"""Podcast schemas for API responses."""
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel
class PodcastStatusEnum(str, Enum):
class PodcastStatusEnum(StrEnum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"


@ -16,6 +16,7 @@ class SearchSourceConnectorBase(BaseModel):
is_indexable: bool
last_indexed_at: datetime | None = None
config: dict[str, Any]
enable_summary: bool = False
periodic_indexing_enabled: bool = False
indexing_frequency_minutes: int | None = None
next_scheduled_at: datetime | None = None
@ -65,6 +66,7 @@ class SearchSourceConnectorUpdate(BaseModel):
is_indexable: bool | None = None
last_indexed_at: datetime | None = None
config: dict[str, Any] | None = None
enable_summary: bool | None = None
periodic_indexing_enabled: bool | None = None
indexing_frequency_minutes: int | None = None
next_scheduled_at: datetime | None = None


@ -1303,10 +1303,9 @@ class ConnectorService:
sources_list = self._build_chunk_sources_from_documents(
github_docs,
description_fn=lambda chunk, _doc_info, metadata: metadata.get(
"description"
)
or chunk.get("content", ""),
description_fn=lambda chunk, _doc_info, metadata: (
metadata.get("description") or chunk.get("content", "")
),
url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
)


@ -0,0 +1,11 @@
from app.services.google_drive.tool_metadata_service import (
GoogleDriveAccount,
GoogleDriveFile,
GoogleDriveToolMetadataService,
)
__all__ = [
"GoogleDriveAccount",
"GoogleDriveFile",
"GoogleDriveToolMetadataService",
]


@ -0,0 +1,149 @@
from dataclasses import dataclass
from sqlalchemy import and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
@dataclass
class GoogleDriveAccount:
id: int
name: str
@classmethod
def from_connector(cls, connector: SearchSourceConnector) -> "GoogleDriveAccount":
return cls(id=connector.id, name=connector.name)
def to_dict(self) -> dict:
return {"id": self.id, "name": self.name}
@dataclass
class GoogleDriveFile:
file_id: str
name: str
mime_type: str
web_view_link: str
connector_id: int
document_id: int
@classmethod
def from_document(cls, document: Document) -> "GoogleDriveFile":
meta = document.document_metadata or {}
return cls(
file_id=meta.get("google_drive_file_id", ""),
name=meta.get("google_drive_file_name", document.title),
mime_type=meta.get("google_drive_mime_type", ""),
web_view_link=meta.get("web_view_link", ""),
connector_id=document.connector_id,
document_id=document.id,
)
def to_dict(self) -> dict:
return {
"file_id": self.file_id,
"name": self.name,
"mime_type": self.mime_type,
"web_view_link": self.web_view_link,
"connector_id": self.connector_id,
"document_id": self.document_id,
}
class GoogleDriveToolMetadataService:
def __init__(self, db_session: AsyncSession):
self._db_session = db_session
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
if not accounts:
return {
"accounts": [],
"supported_types": [],
"error": "No Google Drive account connected",
}
return {
"accounts": [acc.to_dict() for acc in accounts],
"supported_types": ["google_doc", "google_sheet"],
}
async def get_trash_context(
self, search_space_id: int, user_id: str, file_name: str
) -> dict:
result = await self._db_session.execute(
select(Document)
.join(
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
)
.filter(
and_(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
func.lower(Document.title) == func.lower(file_name),
SearchSourceConnector.user_id == user_id,
)
)
)
document = result.scalars().first()
if not document:
return {
"error": (
f"File '{file_name}' not found in your indexed Google Drive files. "
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
"or (3) the file name is different."
)
}
if not document.connector_id:
return {"error": "Document has no associated connector"}
result = await self._db_session.execute(
select(SearchSourceConnector).filter(
and_(
SearchSourceConnector.id == document.connector_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
)
connector = result.scalars().first()
if not connector:
return {"error": "Connector not found or access denied"}
account = GoogleDriveAccount.from_connector(connector)
file = GoogleDriveFile.from_document(document)
return {
"account": account.to_dict(),
"file": file.to_dict(),
}
async def _get_google_drive_accounts(
self, search_space_id: int, user_id: str
) -> list[GoogleDriveAccount]:
result = await self._db_session.execute(
select(SearchSourceConnector)
.filter(
and_(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
)
)
.order_by(SearchSourceConnector.last_indexed_at.desc())
)
connectors = result.scalars().all()
return [GoogleDriveAccount.from_connector(c) for c in connectors]
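
`GoogleDriveFile.from_document` above reads optional keys from the document metadata and falls back to empty strings, or to the document title for the name. A reduced sketch of that mapping over a plain dict, assuming hypothetical names `DriveFileInfo` and `file_info_from_metadata`, and using the standard-library `dataclasses.asdict` where the commit writes `to_dict` by hand:

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class DriveFileInfo:
    file_id: str
    name: str
    mime_type: str
    web_view_link: str


def file_info_from_metadata(meta: Optional[dict], title: str) -> DriveFileInfo:
    """Build a DriveFileInfo, tolerating missing metadata and missing keys."""
    meta = meta or {}
    return DriveFileInfo(
        file_id=meta.get("google_drive_file_id", ""),
        # The stored Drive file name wins; the document title is the fallback.
        name=meta.get("google_drive_file_name", title),
        mime_type=meta.get("google_drive_mime_type", ""),
        web_view_link=meta.get("web_view_link", ""),
    )
```

For flat dataclasses like this one, `asdict(info)` yields a dict with the field names as keys, in the same shape as the hand-written `to_dict` methods in the diff.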


@ -4,12 +4,12 @@ from datetime import datetime
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.linear_connector import LinearConnector
from app.db import Chunk, Document
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
)
@ -80,7 +80,7 @@ class LinearKBSyncService:
state = formatted_issue.get("state", "Unknown")
priority = issue_raw.get("priorityLabel", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
description = formatted_issue.get("description", "")
formatted_issue.get("description", "")
user_llm = await get_user_long_context_llm(
self.db_session, user_id, search_space_id, disable_streaming=True
@ -100,18 +100,10 @@ class LinearKBSyncService:
issue_content, user_llm, document_metadata_for_summary
)
else:
if description and len(description) > 1000:
description = description[:997] + "..."
summary_content = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n"
)
if description:
summary_content += f"Description: {description}\n\n"
summary_content += f"Comments: {comment_count}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
)
summary_embedding = embed_text(summary_content)
await self.db_session.execute(
delete(Chunk).where(Chunk.document_id == document.id)


@ -12,16 +12,35 @@ synchronous ChatLiteLLM-like interface and async methods.
"""
import logging
import re
from typing import Any
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from litellm import Router
from litellm.exceptions import (
BadRequestError as LiteLLMBadRequestError,
ContextWindowExceededError,
)
logger = logging.getLogger(__name__)
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
r"(input tokens exceed|context.{0,20}(length|window|limit)|"
r"maximum context length|token.{0,20}(limit|exceed)|"
r"too many tokens|reduce the length)",
re.IGNORECASE,
)
def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
"""Check if a BadRequestError is actually a context window overflow."""
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
@ -234,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
making it a drop-in replacement for auto-mode routing.
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
window across all router deployments so that deepagents
SummarizationMiddleware can use fraction-based triggers.
"""
# Use model_config for Pydantic v2 compatibility
@ -265,7 +288,6 @@ class ChatLiteLLMRouter(BaseChatModel):
"""
try:
super().__init__(**kwargs)
# Store router and tools as private attributes
resolved_router = router or LLMRouterService.get_router()
object.__setattr__(self, "_router", resolved_router)
object.__setattr__(self, "_bound_tools", bound_tools)
@ -274,6 +296,12 @@ class ChatLiteLLMRouter(BaseChatModel):
raise ValueError(
"LLM Router not initialized. Call LLMRouterService.initialize() first."
)
# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
computed_profile = self._compute_min_context_profile()
if computed_profile is not None:
object.__setattr__(self, "profile", computed_profile)
logger.info(
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
)
@ -281,6 +309,39 @@ class ChatLiteLLMRouter(BaseChatModel):
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
raise
def _compute_min_context_profile(self) -> dict | None:
"""Derive a profile dict with max_input_tokens from router deployments.
Uses litellm.get_model_info to look up each deployment's context window
and picks the *minimum* so that summarization triggers before ANY model
in the pool overflows.
"""
from litellm import get_model_info
if not self._router:
return None
min_ctx: int | None = None
for deployment in self._router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if (
isinstance(ctx, int)
and ctx > 0
and (min_ctx is None or ctx < min_ctx)
):
min_ctx = ctx
except Exception:
continue
if min_ctx is not None:
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
return {"max_input_tokens": min_ctx}
return None
@property
def _llm_type(self) -> str:
return "litellm-router"
@ -359,13 +420,19 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router completion
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
try:
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
except ContextWindowExceededError as e:
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
raise ContextOverflowError(str(e)) from e
raise
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
@ -396,13 +463,19 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router async completion
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
try:
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
**call_kwargs,
)
except ContextWindowExceededError as e:
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
raise ContextOverflowError(str(e)) from e
raise
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
@ -432,14 +505,20 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router completion with streaming
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
try:
response = self._router.completion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
except ContextWindowExceededError as e:
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
raise ContextOverflowError(str(e)) from e
raise
# Yield chunks
for chunk in response:
@ -471,14 +550,20 @@ class ChatLiteLLMRouter(BaseChatModel):
if self._tool_choice is not None:
call_kwargs["tool_choice"] = self._tool_choice
# Call router async completion with streaming
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
try:
response = await self._router.acompletion(
model=self.model,
messages=formatted_messages,
stop=stop,
stream=True,
**call_kwargs,
)
except ContextWindowExceededError as e:
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
raise ContextOverflowError(str(e)) from e
raise
# Yield chunks asynchronously
async for chunk in response:


@ -4,11 +4,11 @@ from datetime import datetime
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import Chunk, Document
from app.services.llm_service import get_user_long_context_llm
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
)
@ -127,10 +127,8 @@ class NotionKBSyncService:
logger.debug(f"Generated summary length: {len(summary_content)} chars")
else:
logger.warning("No LLM configured - using fallback summary")
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content[:500]}..."
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content}"
summary_embedding = embed_text(summary_content)
logger.debug(f"Deleting old chunks for document {document_id}")
await self.db_session.execute(


@ -0,0 +1,53 @@
"""Task dispatcher abstraction for background document processing.
Decouples the upload endpoint from Celery so tests can swap in a
synchronous (inline) implementation that needs only PostgreSQL.
"""
from __future__ import annotations
from typing import Protocol
class TaskDispatcher(Protocol):
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
) -> None: ...
class CeleryTaskDispatcher:
"""Production dispatcher — fires Celery tasks via Redis broker."""
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
) -> None:
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
process_file_upload_with_document_task.delay(
document_id=document_id,
temp_path=temp_path,
filename=filename,
search_space_id=search_space_id,
user_id=user_id,
should_summarize=should_summarize,
)
async def get_task_dispatcher() -> TaskDispatcher:
return CeleryTaskDispatcher()
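The module docstring says tests can swap in a synchronous (inline) dispatcher. A minimal sketch of what such a test double might look like follows; `InlineTaskDispatcher` and `fake_process` are hypothetical names that do not appear in this diff, and the protocol is repeated only so the block is self-contained.

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any, Protocol


class TaskDispatcher(Protocol):
    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
        should_summarize: bool = False,
    ) -> None: ...


class InlineTaskDispatcher:
    """Hypothetical test double: awaits the processing coroutine directly."""

    def __init__(self, process: Callable[..., Awaitable[None]]) -> None:
        self._process = process

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
        should_summarize: bool = False,
    ) -> None:
        # No broker involved: the work finishes before this call returns.
        await self._process(
            document_id=document_id,
            temp_path=temp_path,
            filename=filename,
            search_space_id=search_space_id,
            user_id=user_id,
            should_summarize=should_summarize,
        )


processed: list[dict[str, Any]] = []


async def fake_process(**kwargs: Any) -> None:
    # Stand-in for the real file-processing coroutine.
    processed.append(kwargs)


dispatcher: TaskDispatcher = InlineTaskDispatcher(fake_process)
asyncio.run(
    dispatcher.dispatch_file_processing(
        document_id=1,
        temp_path="/tmp/example.pdf",
        filename="example.pdf",
        search_space_id=7,
        user_id="user-1",
    )
)
```

Because the protocol is structural, the double needs no inheritance; any object with a matching `dispatch_file_processing` method can be returned wherever `get_task_dispatcher` is consumed.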


@ -626,6 +626,7 @@ def process_file_upload_with_document_task(
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
):
"""
Celery task to process uploaded file with existing pending document.
@ -640,6 +641,7 @@ def process_file_upload_with_document_task(
filename: Original filename
search_space_id: ID of the search space
user_id: ID of the user
should_summarize: Whether to generate an LLM summary
"""
import traceback
@ -674,7 +676,12 @@ def process_file_upload_with_document_task(
try:
loop.run_until_complete(
_process_file_with_document(
document_id, temp_path, filename, search_space_id, user_id
document_id,
temp_path,
filename,
search_space_id,
user_id,
should_summarize=should_summarize,
)
)
logger.info(
@ -710,6 +717,7 @@ async def _process_file_with_document(
filename: str,
search_space_id: int,
user_id: str,
should_summarize: bool = False,
):
"""
Process file and update existing pending document status.
@ -811,6 +819,7 @@ async def _process_file_with_document(
task_logger=task_logger,
log_entry=log_entry,
notification=notification,
should_summarize=should_summarize,
)
# Update notification on success


@ -9,17 +9,21 @@ Supports loading LLM configurations from:
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
"""
import asyncio
import json
import logging
import re
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import logging
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
@ -30,7 +34,20 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
)
from app.db import (
ChatVisibility,
Document,
NewChatMessage,
NewChatThread,
Report,
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import (
clear_ai_responding,
@ -40,6 +57,16 @@ from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
if not _perf_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
_perf_log.addHandler(_h)
_perf_log.propagate = False
_background_tasks: set[asyncio.Task] = set()
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
"""
@ -187,6 +214,7 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
async def _stream_agent_events(
@ -404,6 +432,21 @@ async def _stream_agent_events(
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "execute":
cmd = (
tool_input.get("command", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_cmd = cmd[:80] + ("..." if len(cmd) > 80 else "")
last_active_step_title = "Running command"
last_active_step_items = [f"$ {display_cmd}"]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Running command",
status="in_progress",
items=last_active_step_items,
)
else:
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
last_active_step_items = []
@ -620,6 +663,32 @@ async def _stream_agent_events(
status="completed",
items=completed_items,
)
elif tool_name == "execute":
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
exit_code_val = int(m.group(1)) if m else None
if exit_code_val is not None and exit_code_val == 0:
completed_items = [
*last_active_step_items,
"Completed successfully",
]
elif exit_code_val is not None:
completed_items = [
*last_active_step_items,
f"Exit code: {exit_code_val}",
]
else:
completed_items = [*last_active_step_items, "Finished"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Running command",
status="completed",
items=completed_items,
)
elif tool_name == "ls":
if isinstance(tool_output, dict):
ls_output = tool_output.get("result", "")
@ -804,6 +873,8 @@ async def _stream_agent_events(
"create_linear_issue",
"update_linear_issue",
"delete_linear_issue",
"create_google_drive_file",
"delete_google_drive_file",
):
yield streaming_service.format_tool_output_available(
tool_call_id,
@ -811,6 +882,36 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
elif tool_name == "execute":
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
exit_code: int | None = None
output_text = raw_text
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
if m:
exit_code = int(m.group(1))
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
output_text = om.group(1) if om else ""
thread_id_str = config.get("configurable", {}).get("thread_id", "")
for sf_match in re.finditer(
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
):
fpath = sf_match.group(1).strip()
if fpath and fpath not in result.sandbox_files:
result.sandbox_files.append(fpath)
yield streaming_service.format_tool_output_available(
tool_call_id,
{
"exit_code": exit_code,
"output": output_text,
"thread_id": thread_id_str,
},
)
else:
yield streaming_service.format_tool_output_available(
tool_call_id,
@ -879,6 +980,38 @@ async def _stream_agent_events(
yield streaming_service.format_interrupt_request(result.interrupt_value)
def _try_persist_and_delete_sandbox(
thread_id: int,
sandbox_files: list[str],
) -> None:
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if not is_sandbox_enabled():
return
async def _run() -> None:
try:
await persist_and_delete_sandbox(thread_id, sandbox_files)
except Exception:
logging.getLogger(__name__).warning(
"persist_and_delete_sandbox failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_run())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def stream_new_chat(
user_query: str,
search_space_id: int,
@ -915,6 +1048,8 @@ async def stream_new_chat(
str: SSE formatted response strings
"""
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
try:
# Mark AI as responding to this user for live collaboration
@ -923,6 +1058,7 @@ async def stream_new_chat(
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
# Positive ID: Load from NewLLMConfig database table
agent_config = await load_agent_config(
@ -953,6 +1089,11 @@ async def stream_new_chat(
llm = create_chat_litellm_from_config(llm_config)
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0,
llm_config_id,
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
@ -960,22 +1101,45 @@ async def stream_new_chat(
return
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
# Get Firecrawl API key from webcrawler connector if configured
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -987,20 +1151,22 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
# Build input with message history
langchain_messages = []
_t0 = time.perf_counter()
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=visibility
)
# Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
@ -1012,11 +1178,9 @@ async def stream_new_chat(
# Fetch mentioned documents if any (with chunks for proper citations)
mentioned_documents: list[Document] = []
if mentioned_document_ids:
from sqlalchemy.orm import selectinload as doc_selectinload
result = await session.execute(
select(Document)
.options(doc_selectinload(Document.chunks))
.options(selectinload(Document.chunks))
.filter(
Document.id.in_(mentioned_document_ids),
Document.search_space_id == search_space_id,
@ -1027,8 +1191,6 @@ async def stream_new_chat(
# Fetch mentioned SurfSense docs if any
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
if mentioned_surfsense_doc_ids:
from sqlalchemy.orm import selectinload
result = await session.execute(
select(SurfsenseDocsDocument)
.options(selectinload(SurfsenseDocsDocument.chunks))
@ -1112,6 +1274,11 @@ async def stream_new_chat(
"search_space_id": search_space_id,
}
_perf_log.info(
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
time.perf_counter() - _t0,
)
# All pre-streaming DB reads are done. Commit to release the
# transaction and its ACCESS SHARE locks so we don't block DDL
# (e.g. migrations) for the entire duration of LLM streaming.
@ -1119,6 +1286,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
# Configure LangGraph with thread_id for memory
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
configurable = {"thread_id": str(chat_id)}
@ -1180,7 +1353,8 @@ async def stream_new_chat(
items=initial_items,
)
stream_result = StreamResult()
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1192,8 +1366,23 @@ async def stream_new_chat(
initial_step_title=initial_title,
initial_step_items=initial_items,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
@ -1202,12 +1391,6 @@ async def stream_new_chat(
accumulated_text = stream_result.accumulated_text
# Generate LLM title for new chats after first response
# Check if this is the first assistant response by counting existing assistant messages
from sqlalchemy import func
from app.db import NewChatMessage, NewChatThread
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
NewChatMessage.thread_id == chat_id,
@ -1294,6 +1477,8 @@ async def stream_new_chat(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
async def stream_resume_chat(
chat_id: int,
@ -1305,12 +1490,15 @@ async def stream_resume_chat(
thread_visibility: ChatVisibility | None = None,
) -> AsyncGenerator[str, None]:
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
agent_config = await load_agent_config(
session=session,
@ -1334,26 +1522,54 @@ async def stream_resume_chat(
return
llm = create_chat_litellm_from_config(llm_config)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
yield streaming_service.format_done()
return
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -1365,11 +1581,21 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
from langgraph.types import Command
config = {
@ -1380,7 +1606,8 @@ async def stream_resume_chat(
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
stream_result = StreamResult()
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1389,7 +1616,20 @@ async def stream_resume_chat(
result=stream_result,
step_prefix="thinking-resume",
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
@ -1423,3 +1663,5 @@ async def stream_resume_chat(
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)


@ -12,13 +12,13 @@ from collections.abc import Awaitable, Callable
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.airtable_history import AirtableHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -399,7 +399,7 @@ async def index_airtable_records(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"record_id": item["record_id"],
"created_time": item["record"].get("CREATED_TIME()", ""),
@ -415,11 +415,8 @@ async def index_airtable_records(
document_metadata_for_summary,
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Airtable Record: {item['record_id']}\n\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Airtable Record: {item['record_id']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
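The same branch recurs across the connector indexers in this diff: summarize with the LLM only when one is configured and the connector has summaries enabled, otherwise embed a short header plus the full content. A simplified sketch of that decision follows; `build_summary` and its callable parameters are hypothetical, and the real `generate_document_summary` and `embed_text` helpers are replaced by injected stand-ins so the block runs on its own.

```python
from collections.abc import Callable


def build_summary(
    header: str,
    full_content: str,
    *,
    llm_available: bool,
    enable_summary: bool,
    summarize: Callable[[str], str],
    embed: Callable[[str], list[float]],
) -> tuple[str, list[float]]:
    """Return (summary_text, summary_embedding) for one indexed item."""
    if llm_available and enable_summary:
        summary = summarize(full_content)
    else:
        # Fallback: no truncation, the header is followed by the full content.
        summary = f"{header}\n\n{full_content}"
    return summary, embed(summary)


def fake_embed(text: str) -> list[float]:
    # Stand-in embedding: a single number derived from the text length.
    return [float(len(text))]


summary, embedding = build_summary(
    "Airtable Record: rec1",
    "body",
    llm_available=True,
    enable_summary=False,
    summarize=lambda text: "LLM summary",
    embed=fake_embed,
)
```

With summaries disabled the LLM callable is never invoked, so indexing cost for such connectors is limited to embedding.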


@ -13,13 +13,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.bookstack_connector import BookStackConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -403,7 +403,7 @@ async def index_bookstack_pages(
"connector_id": connector_id,
}
if user_llm:
if user_llm and connector.enable_summary:
summary_metadata = {
"page_name": item["page_name"],
"page_id": item["page_id"],
@ -418,17 +418,8 @@ async def index_bookstack_pages(
item["full_content"], user_llm, summary_metadata
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n"
if item["page_content"]:
# Take first 1000 characters of content for summary
content_preview = item["page_content"][:1000]
if len(item["page_content"]) > 1000:
content_preview += "..."
summary_content += f"Content Preview: {content_preview}\n\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content
chunks = await create_document_chunks(item["full_content"])


@ -14,13 +14,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.clickup_history import ClickUpHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -398,7 +398,7 @@ async def index_clickup_tasks(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"task_id": item["task_id"],
"task_name": item["task_name"],
@ -418,9 +418,7 @@ async def index_clickup_tasks(
)
else:
summary_content = item["task_content"]
summary_embedding = config.embedding_model_instance.embed(
item["task_content"]
)
summary_embedding = embed_text(item["task_content"])
chunks = await create_document_chunks(item["task_content"])


@ -14,13 +14,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -378,7 +378,7 @@ async def index_confluence_pages(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
@ -394,18 +394,8 @@ async def index_confluence_pages(
item["full_content"], user_llm, document_metadata
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n"
if item["page_content"]:
# Take first 1000 characters of content for summary
content_preview = item["page_content"][:1000]
if len(item["page_content"]) > 1000:
content_preview += "..."
summary_content += f"Content Preview: {content_preview}\n\n"
summary_content += f"Comments: {item['comment_count']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content with comments
chunks = await create_document_chunks(item["full_content"])


@ -23,6 +23,7 @@ from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnector
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_unique_identifier_hash,
)
@ -669,9 +670,7 @@ async def index_discord_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = config.embedding_model_instance.embed(
item["combined_document_string"]
)
doc_embedding = embed_text(item["combined_document_string"])
# Update document to READY with actual content
document.title = f"{item['guild_name']}#{item['channel_name']}"


@ -16,13 +16,13 @@ from datetime import UTC, datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.github_connector import GitHubConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -367,7 +367,7 @@ async def index_github_repos(
"estimated_tokens": digest.estimated_tokens,
}
if user_llm:
if user_llm and connector.enable_summary:
# Prepare content for summarization
summary_content = digest.full_digest
if len(summary_content) > MAX_DIGEST_CHARS:
@ -381,15 +381,12 @@ async def index_github_repos(
summary_content, user_llm, document_metadata_for_summary
)
else:
# Fallback to simple summary if no LLM configured
summary_text = (
f"# GitHub Repository: {repo_full_name}\n\n"
f"## Summary\n{digest.summary}\n\n"
f"## File Structure\n{digest.tree[:3000]}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_text
f"## File Structure\n{digest.tree}"
)
summary_embedding = embed_text(summary_text)
# Chunk the full digest content for granular search
try:
@ -551,7 +548,7 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list:
chunks.append(
Chunk(
content=chunk_text,
embedding=config.embedding_model_instance.embed(chunk_text),
embedding=embed_text(chunk_text),
)
)


@ -20,6 +20,7 @@ from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -489,7 +490,7 @@ async def index_google_calendar_events(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
@ -507,22 +508,8 @@ async def index_google_calendar_events(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
summary_content = (
f"Google Calendar Event: {item['event_summary']}\n\n"
)
summary_content += f"Calendar: {item['calendar_id']}\n"
summary_content += f"Start: {item['start_time']}\n"
summary_content += f"End: {item['end_time']}\n"
if item["location"]:
summary_content += f"Location: {item['location']}\n"
if item["description"]:
desc_preview = item["description"][:1000]
if len(item["description"]) > 1000:
desc_preview += "..."
summary_content += f"Description: {desc_preview}\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])


@ -352,7 +352,7 @@ async def index_google_drive_single_file(
await session.commit()
# Process the file
indexed, skipped, failed = await _process_single_file(
indexed, _skipped, failed = await _process_single_file(
drive_client=drive_client,
session=session,
file=file,
@ -608,7 +608,7 @@ async def _index_with_delta_sync(
{"stage": "delta_sync", "start_token": start_page_token},
)
changes, final_token, error = await fetch_all_changes(
changes, _final_token, error = await fetch_all_changes(
drive_client, start_page_token, folder_id
)
@ -1011,7 +1011,7 @@ async def _process_single_file(
pending_document.status = DocumentStatus.processing()
await session.commit()
_, error, metadata = await download_and_process_file(
_, error, _metadata = await download_and_process_file(
client=drive_client,
file=file,
search_space_id=search_space_id,

View file

@ -25,6 +25,7 @@ from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -413,7 +414,7 @@ async def index_google_gmail_messages(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
@ -432,12 +433,8 @@ async def index_google_gmail_messages(
document_metadata_for_summary,
)
else:
summary_content = f"Google Gmail Message: {item['subject']}\n\n"
summary_content += f"Sender: {item['sender']}\n"
summary_content += f"Date: {item['date_str']}\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])

View file

@ -14,13 +14,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -356,7 +356,7 @@ async def index_jira_issues(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata = {
"issue_key": item["issue_identifier"],
"issue_title": item["issue_title"],
@ -373,14 +373,8 @@ async def index_jira_issues(
item["issue_content"], user_llm, document_metadata
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['formatted_issue'].get('status', 'Unknown')}\n\n"
if item["formatted_issue"].get("description"):
summary_content += f"Description: {item['formatted_issue'].get('description')}\n\n"
summary_content += f"Comments: {item['comment_count']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full issue content with comments
chunks = await create_document_chunks(item["issue_content"])

View file

@ -13,13 +13,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -395,7 +395,7 @@ async def index_linear_issues(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"issue_id": item["issue_identifier"],
"issue_title": item["issue_title"],
@ -412,17 +412,8 @@ async def index_linear_issues(
item["issue_content"], user_llm, document_metadata_for_summary
)
else:
# Fallback to simple summary if no LLM configured
description = item["description"]
if description and len(description) > 1000:
description = description[:997] + "..."
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n"
if description:
summary_content += f"Description: {description}\n\n"
summary_content += f"Comments: {item['comment_count']}"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["issue_content"])

View file

@ -13,13 +13,13 @@ from datetime import datetime, timedelta
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.luma_connector import LumaConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -441,7 +441,7 @@ async def index_luma_events(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_name": item["event_name"],
@ -462,29 +462,10 @@ async def index_luma_events(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Luma Event: {item['event_name']}\n\n"
if item["event_url"]:
summary_content += f"URL: {item['event_url']}\n"
summary_content += f"Start: {item['start_at']}\n"
summary_content += f"End: {item['end_at']}\n"
if item["timezone"]:
summary_content += f"Timezone: {item['timezone']}\n"
if item["location"]:
summary_content += f"Location: {item['location']}\n"
if item["city"]:
summary_content += f"City: {item['city']}\n"
if item["host_names"]:
summary_content += f"Hosts: {item['host_names']}\n"
if item["description"]:
desc_preview = item["description"][:1000]
if len(item["description"]) > 1000:
desc_preview += "..."
summary_content += f"Description: {desc_preview}\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
summary_content = (
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
)
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])

View file

@ -13,13 +13,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.notion_history import NotionHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -447,7 +447,7 @@ async def index_notion_pages(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"page_title": item["page_title"],
"page_id": item["page_id"],
@ -463,11 +463,8 @@ async def index_notion_pages(
document_metadata_for_summary,
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content'][:500]}..."
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])

View file

@ -26,6 +26,7 @@ from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -546,7 +547,7 @@ async def index_obsidian_vault(
# Generate summary
summary_content = ""
if long_context_llm:
if long_context_llm and connector.enable_summary:
summary_content, _ = await generate_document_summary(
document_string,
long_context_llm,
@ -554,7 +555,7 @@ async def index_obsidian_vault(
)
# Generate embedding
embedding = config.embedding_model_instance.embed(document_string)
embedding = embed_text(document_string)
# Add URL and summary to metadata
document_metadata["url"] = f"obsidian://{vault_name}/{relative_path}"

View file

@ -17,12 +17,12 @@ from slack_sdk.errors import SlackApiError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.slack_history import SlackHistory
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_unique_identifier_hash,
)
@ -542,9 +542,7 @@ async def index_slack_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = config.embedding_model_instance.embed(
item["combined_document_string"]
)
doc_embedding = embed_text(item["combined_document_string"])
# Update document to READY with actual content
document.title = f"{item['team_name']}#{item['channel_name']}"

View file

@ -16,12 +16,12 @@ from datetime import UTC, datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.teams_history import TeamsHistory
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_unique_identifier_hash,
)
@ -581,9 +581,7 @@ async def index_teams_messages(
# Heavy processing (embeddings, chunks)
chunks = await create_document_chunks(item["combined_document_string"])
doc_embedding = config.embedding_model_instance.embed(
item["combined_document_string"]
)
doc_embedding = embed_text(item["combined_document_string"])
# Update document to READY with actual content
document.title = f"{item['team_name']} - {item['channel_name']}"

View file

@ -13,13 +13,13 @@ from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.connectors.webcrawler_connector import WebCrawlerConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -377,7 +377,7 @@ async def index_crawled_urls(
session, user_id, search_space_id
)
if user_llm:
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"url": url,
"title": title,
@ -393,24 +393,8 @@ async def index_crawled_urls(
structured_document, user_llm, document_metadata_for_summary
)
else:
# Fallback to simple summary if no LLM configured
summary_content = f"Crawled URL: {title}\n\n"
summary_content += f"URL: {url}\n"
if description:
summary_content += f"Description: {description}\n"
if language:
summary_content += f"Language: {language}\n"
summary_content += f"Crawler: {crawler_type}\n\n"
# Add content preview
content_preview = content[:1000]
if len(content) > 1000:
content_preview += "..."
summary_content += f"Content Preview:\n{content_preview}\n"
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
summary_content = f"Crawled URL: {title}\n\nURL: {url}\n\n{content}"
summary_embedding = embed_text(summary_content)
# Process chunks
chunks = await create_document_chunks(content)

View file

@ -18,12 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType, Log, Notification
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
from app.services.llm_service import get_user_long_context_llm
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
convert_document_to_markdown,
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
@ -33,7 +35,6 @@ from .base import (
check_document_by_unique_identifier,
check_duplicate_document,
get_current_timestamp,
safe_set_chunks,
)
from .markdown_processor import add_received_markdown_file_document
@ -760,11 +761,7 @@ async def add_received_file_document_using_docling(
f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
)
from app.config import config
summary_embedding = config.embedding_model_instance.embed(
enhanced_summary_content
)
summary_embedding = embed_text(enhanced_summary_content)
# Process chunks
chunks = await create_document_chunks(file_in_markdown)
@ -1599,6 +1596,7 @@ async def process_file_in_background_with_document(
log_entry: Log,
connector: dict | None = None,
notification: Notification | None = None,
should_summarize: bool = False,
) -> Document | None:
"""
Process file and update existing pending document (2-phase pattern).
@ -1632,6 +1630,8 @@ async def process_file_in_background_with_document(
from app.config import config as app_config
from app.services.llm_service import get_user_long_context_llm
doc_id = document.id
try:
markdown_content = None
etl_service = None
@ -1855,7 +1855,7 @@ async def process_file_in_background_with_document(
content_hash = generate_content_hash(markdown_content, search_space_id)
existing_by_content = await check_duplicate_document(session, content_hash)
if existing_by_content and existing_by_content.id != document.id:
if existing_by_content and existing_by_content.id != doc_id:
# Duplicate content found - mark this document as failed
logging.info(
f"Duplicate content detected for {filename}, "
@ -1863,7 +1863,7 @@ async def process_file_in_background_with_document(
)
return None
# ===== STEP 3: Generate embeddings and chunks =====
# ===== STEP 3+4: Index via pipeline =====
if notification:
await NotificationService.document_processing.notify_processing_progress(
session, notification, stage="chunking"
@ -1871,57 +1871,24 @@ async def process_file_in_background_with_document(
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if user_llm:
document_metadata = {
"file_name": filename,
"etl_service": etl_service,
"document_type": "File Document",
}
summary_content, summary_embedding = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
# Fallback: use truncated content as summary
summary_content = markdown_content[:4000]
from app.config import config
summary_embedding = config.embedding_model_instance.embed(summary_content)
chunks = await create_document_chunks(markdown_content)
# ===== STEP 4: Update document to READY =====
from sqlalchemy.orm.attributes import flag_modified
document.title = filename
document.content = summary_content
document.content_hash = content_hash
document.embedding = summary_embedding
document.document_metadata = {
"FILE_NAME": filename,
"ETL_SERVICE": etl_service or "UNKNOWN",
**(document.document_metadata or {}),
}
flag_modified(document, "document_metadata")
# Use safe_set_chunks to avoid async issues
safe_set_chunks(document, chunks)
document.source_markdown = markdown_content
document.content_needs_reindexing = False
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready() # Shows checkmark in UI
await session.commit()
await session.refresh(document)
await index_uploaded_file(
markdown_content=markdown_content,
filename=filename,
etl_service=etl_service,
search_space_id=search_space_id,
user_id=user_id,
session=session,
llm=user_llm,
should_summarize=should_summarize,
)
await task_logger.log_task_success(
log_entry,
f"Successfully processed file: {filename}",
{
"document_id": document.id,
"document_id": doc_id,
"content_hash": content_hash,
"file_type": etl_service,
"chunks_count": len(chunks),
},
)
@ -1946,7 +1913,7 @@ async def process_file_in_background_with_document(
{
"error_type": type(e).__name__,
"filename": filename,
"document_id": document.id,
"document_id": doc_id,
},
)
logging.error(f"Error processing file with document: {error_message}")

View file

@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
from app.config import config
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
from app.utils.document_converters import embed_text
logger = logging.getLogger(__name__)
@ -89,7 +90,7 @@ def create_surfsense_docs_chunks(content: str) -> list[SurfsenseDocsChunk]:
return [
SurfsenseDocsChunk(
content=chunk.text,
embedding=config.embedding_model_instance.embed(chunk.text),
embedding=embed_text(chunk.text),
)
for chunk in config.chunker_instance.chunk(content)
]
@ -154,7 +155,7 @@ async def index_surfsense_docs(session: AsyncSession) -> tuple[int, int, int, in
existing_doc.title = title
existing_doc.content = content
existing_doc.content_hash = content_hash
existing_doc.embedding = config.embedding_model_instance.embed(content)
existing_doc.embedding = embed_text(content)
existing_doc.chunks = chunks
existing_doc.updated_at = datetime.now(UTC)
@ -170,7 +171,7 @@ async def index_surfsense_docs(session: AsyncSession) -> tuple[int, int, int, in
title=title,
content=content,
content_hash=content_hash,
embedding=config.embedding_model_instance.embed(content),
embedding=embed_text(content),
chunks=chunks,
updated_at=datetime.now(UTC),
)

View file

@ -1,11 +1,59 @@
import hashlib
import logging
import warnings
import numpy as np
from litellm import get_model_info, token_counter
from app.config import config
from app.db import Chunk, DocumentType
from app.prompts import SUMMARY_PROMPT_TEMPLATE
logger = logging.getLogger(__name__)
def _get_embedding_max_tokens() -> int:
"""Get the max token limit for the configured embedding model.
Checks model properties in order: max_seq_length, _max_tokens.
Falls back to 8192 (OpenAI embedding default).
"""
model = config.embedding_model_instance
for attr in ("max_seq_length", "_max_tokens"):
val = getattr(model, attr, None)
if isinstance(val, int) and val > 0:
return val
return 8192
def truncate_for_embedding(text: str) -> str:
"""Truncate text to fit within the embedding model's context window.
Uses the embedding model's own tokenizer for accurate token counting,
so the result is model-agnostic regardless of the underlying provider.
"""
max_tokens = _get_embedding_max_tokens()
if len(text) // 3 <= max_tokens:
return text
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
return text
warnings.warn(
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
stacklevel=2,
)
return tokenizer.decode(tokens[:max_tokens])
def embed_text(text: str) -> np.ndarray:
"""Truncate text to fit and embed it. Drop-in replacement for
``config.embedding_model_instance.embed(text)`` that never exceeds the
model's context window."""
return config.embedding_model_instance.embed(truncate_for_embedding(text))
def get_model_context_window(model_name: str) -> int:
"""Get the total context window size for a model (input + output tokens)."""
@ -146,7 +194,7 @@ async def generate_document_summary(
else:
enhanced_summary_content = summary_content
summary_embedding = config.embedding_model_instance.embed(enhanced_summary_content)
summary_embedding = embed_text(enhanced_summary_content)
return enhanced_summary_content, summary_embedding
@ -164,7 +212,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
return [
Chunk(
content=chunk.text,
embedding=config.embedding_model_instance.embed(chunk.text),
embedding=embed_text(chunk.text),
)
for chunk in config.chunker_instance.chunk(content)
]
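
The `truncate_for_embedding` helper added in this file uses a cheap character-count check before it tokenizes, and only decodes a truncated token list when the text really is too long. A minimal sketch of that flow, assuming a hypothetical `WhitespaceTokenizer` in place of the embedding model's real tokenizer and an explicit `max_tokens` argument instead of the value looked up from `config.embedding_model_instance`:

```python
class WhitespaceTokenizer:
    """Hypothetical tokenizer: one token per whitespace-separated word."""

    def encode(self, text: str) -> list[str]:
        return text.split()

    def decode(self, tokens: list[str]) -> str:
        return " ".join(tokens)


def truncate_to_tokens(text: str, tokenizer: WhitespaceTokenizer, max_tokens: int) -> str:
    # Cheap pre-check taken from the diff: skip tokenizing when the text is
    # short enough that it is assumed to fit (roughly 3 characters per token).
    if len(text) // 3 <= max_tokens:
        return text
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text
    # Keep only the first max_tokens tokens and turn them back into text.
    return tokenizer.decode(tokens[:max_tokens])


short_result = truncate_to_tokens("hello world", WhitespaceTokenizer(), 10)
long_result = truncate_to_tokens("word " * 100, WhitespaceTokenizer(), 10)
```

The pre-check is a heuristic: text whose tokens average fewer than three characters can pass it while still exceeding the limit, so it trades a little accuracy for not tokenizing every short string.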

View file

@ -29,4 +29,7 @@ if __name__ == "__main__":
config = uvicorn.Config(**config_kwargs)
server = uvicorn.Server(config)
server.run()
if sys.platform == "win32":
asyncio.run(server.serve(), loop_factory=asyncio.SelectorEventLoop)
else:
server.run()

View file

@ -17,6 +17,7 @@ dependencies = [
"kokoro>=0.9.4",
"linkup-sdk>=0.2.4",
"llama-cloud-services>=0.6.25",
"Markdown>=3.7",
"markdownify>=0.14.1",
"notion-client>=2.3.0",
"numpy>=1.24.0",
@ -65,11 +66,16 @@ dependencies = [
"pypandoc_binary>=1.16.2",
"typst>=0.14.0",
"deepagents>=0.4.3",
"langchain-daytona>=0.0.2",
]
[dependency-groups]
dev = [
"ruff>=0.12.5",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
"pytest-mock>=3.14",
"httpx>=0.28.1",
]
[tool.ruff]
@ -157,10 +163,27 @@ line-ending = "auto"
[tool.ruff.lint.isort]
# Group imports by type
known-first-party = ["app"]
known-first-party = ["app", "tests"]
force-single-line = false
combine-as-imports = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
asyncio_default_test_loop_scope = "session"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
markers = [
"unit: pure logic tests, no DB or external services",
"integration: tests that require a real PostgreSQL database"
]
filterwarnings = [
"ignore::UserWarning:chonkie",
]
[tool.setuptools.packages.find]
where = ["."]
include = ["app*", "alembic*"]

View file

@ -0,0 +1,61 @@
"""Root conftest — shared fixtures available to all test modules."""
from __future__ import annotations
import os
_DEFAULT_TEST_DB = (
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
)
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
# Force the app to use the test database regardless of any pre-existing
# DATABASE_URL in the environment (e.g. from .env or shell profile).
os.environ["DATABASE_URL"] = TEST_DATABASE_URL
import pytest # noqa: E402
from app.db import DocumentType # noqa: E402
from app.indexing_pipeline.connector_document import ConnectorDocument # noqa: E402
# ---------------------------------------------------------------------------
# Unit test fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_user_id() -> str:
return "00000000-0000-0000-0000-000000000001"
@pytest.fixture
def sample_search_space_id() -> int:
return 1
@pytest.fixture
def sample_connector_id() -> int:
return 42
@pytest.fixture
def make_connector_document():
"""
Generic factory for unit tests. Overridden in tests/integration/conftest.py
with real DB-backed IDs for integration tests.
"""
def _make(**overrides):
defaults = {
"title": "Test Document",
"source_markdown": "## Heading\n\nSome content.",
"unique_id": "test-id-001",
"document_type": DocumentType.CLICKUP_CONNECTOR,
"search_space_id": 1,
"connector_id": 1,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
return _make
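
The `make_connector_document` fixture above is a factory: it returns a function that merges keyword overrides into a dict of defaults, so each test states only the fields it cares about. A minimal sketch of the same pattern outside pytest, assuming a hypothetical `FakeDocument` dataclass in place of `ConnectorDocument`:

```python
from dataclasses import dataclass


@dataclass
class FakeDocument:
    # Hypothetical stand-in for ConnectorDocument with only three fields.
    title: str
    unique_id: str
    search_space_id: int


def make_document(**overrides) -> FakeDocument:
    defaults = {
        "title": "Test Document",
        "unique_id": "test-id-001",
        "search_space_id": 1,
    }
    # Caller-supplied values win over the defaults.
    defaults.update(overrides)
    return FakeDocument(**defaults)


default_doc = make_document()
renamed_doc = make_document(title="Renamed", search_space_id=7)
```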

View file

@ -0,0 +1,51 @@
# SurfSense Test Document
## Overview
This is a **sample markdown document** used for end-to-end testing of the manual
document upload pipeline. It includes various markdown formatting elements.
## Key Features
- Document upload and processing
- Automatic chunking of content
- Embedding generation for semantic search
- Real-time status tracking via ElectricSQL
## Technical Architecture
### Backend Stack
The SurfSense backend is built with:
1. **FastAPI** for the REST API
2. **PostgreSQL** with pgvector for vector storage
3. **Celery** with Redis for background task processing
4. **Docling/Unstructured** for document parsing (ETL)
### Processing Pipeline
Documents go through a multi-stage pipeline:
| Stage | Description |
|-------|-------------|
| Upload | File received via API endpoint |
| Parsing | Content extracted using ETL service |
| Chunking | Text split into semantic chunks |
| Embedding | Vector representations generated |
| Storage | Chunks stored with embeddings in pgvector |
## Code Example
```python
async def process_document(file_path: str) -> Document:
content = extract_content(file_path)
chunks = create_chunks(content)
embeddings = generate_embeddings(chunks)
return store_document(chunks, embeddings)
```
## Conclusion
This document serves as a test fixture to validate the complete document processing
pipeline from upload through to chunk creation and embedding storage.

Binary file not shown.

View file

@ -0,0 +1,34 @@
SurfSense Document Upload Test
This is a sample text document used for end-to-end testing of the manual document
upload pipeline in SurfSense. The document contains multiple paragraphs to ensure
that the chunking system has enough content to work with.
Artificial Intelligence and Machine Learning
Artificial intelligence (AI) is a broad field of computer science concerned with
building smart machines capable of performing tasks that typically require human
intelligence. Machine learning is a subset of AI that enables systems to learn and
improve from experience without being explicitly programmed.
Natural Language Processing
Natural language processing (NLP) is a subfield of linguistics, computer science,
and artificial intelligence concerned with the interactions between computers and
human language. Key applications include machine translation, sentiment analysis,
text summarization, and question answering systems.
Vector Databases and Semantic Search
Vector databases store data as high-dimensional vectors, enabling efficient
similarity search operations. When combined with embedding models, they power
semantic search systems that understand the meaning behind queries rather than
relying on exact keyword matches. This technology is fundamental to modern
retrieval-augmented generation (RAG) systems.
Document Processing Pipelines
Modern document processing pipelines involve several stages: extraction, transformation,
chunking, embedding generation, and storage. Each stage plays a critical role in
converting raw documents into searchable, structured knowledge that can be retrieved
and used by AI systems for accurate information retrieval and generation.

View file

@ -0,0 +1,168 @@
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config as app_config
from app.db import (
Base,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from tests.conftest import TEST_DATABASE_URL
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
@pytest_asyncio.fixture(scope="session")
async def async_engine():
engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=NullPool,
echo=False,
# Required for asyncpg + savepoints: disables prepared statement cache
# to prevent "another operation is in progress" errors during savepoint rollbacks.
connect_args={"prepared_statement_cache_size": 0},
)
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.run_sync(Base.metadata.create_all)
yield engine
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
# DROP SCHEMA CASCADE handles this without needing topological sort.
async with engine.begin() as conn:
await conn.execute(text("DROP SCHEMA public CASCADE"))
await conn.execute(text("CREATE SCHEMA public"))
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(async_engine) -> AsyncSession:
# Bind the session to a connection that holds an outer transaction.
# join_transaction_mode="create_savepoint" makes session.commit() release
# a SAVEPOINT instead of committing the outer transaction, so the final
# transaction.rollback() undoes everything — including commits made by the
# service under test — leaving the DB clean for the next test.
async with async_engine.connect() as conn:
transaction = await conn.begin()
async with AsyncSession(
bind=conn,
expire_on_commit=False,
join_transaction_mode="create_savepoint",
) as session:
yield session
await transaction.rollback()
@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
user = User(
id=uuid.uuid4(),
email="test@surfsense.net",
hashed_password="hashed",
is_active=True,
is_superuser=False,
is_verified=True,
)
db_session.add(user)
await db_session.flush()
return user
@pytest_asyncio.fixture
async def db_connector(
db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
) -> SearchSourceConnector:
connector = SearchSourceConnector(
name="Test Connector",
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
config={},
search_space_id=db_search_space.id,
user_id=db_user.id,
)
db_session.add(connector)
await db_session.flush()
return connector
@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
space = SearchSpace(
name="Test Space",
user_id=db_user.id,
)
db_session.add(space)
await db_session.flush()
return space
@pytest.fixture
def patched_summarize(monkeypatch) -> AsyncMock:
mock = AsyncMock(return_value="Mocked summary.")
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
mock,
)
return mock
@pytest.fixture
def patched_summarize_raises(monkeypatch) -> AsyncMock:
mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
mock,
)
return mock
@pytest.fixture
def patched_embed_text(monkeypatch) -> MagicMock:
mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_text",
mock,
)
return mock
@pytest.fixture
def patched_chunk_text(monkeypatch) -> MagicMock:
mock = MagicMock(return_value=["Test chunk content."])
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock,
)
return mock
@pytest.fixture
def make_connector_document(db_connector, db_user):
"""Integration-scoped override: uses real DB connector and user IDs."""
def _make(**overrides):
defaults = {
"title": "Test Document",
"source_markdown": "## Heading\n\nSome content.",
"unique_id": "test-id-001",
"document_type": DocumentType.CLICKUP_CONNECTOR,
"search_space_id": db_connector.search_space_id,
"connector_id": db_connector.id,
"created_by_id": str(db_user.id),
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
return _make
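
The `db_session` fixture in this file relies on an outer transaction plus savepoints: a "commit" issued by the code under test only releases a savepoint, and the final rollback discards everything. A minimal sketch of that idea using the standard-library `sqlite3` module instead of asyncpg and SQLAlchemy (a swapped-in technique for illustration; the table and function names are hypothetical):

```python
import sqlite3

# isolation_level=None puts the connection in autocommit mode, so the
# transaction boundaries below are exactly the ones written out explicitly.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE documents (title TEXT)")


def service_under_test(c: sqlite3.Connection) -> None:
    # What the service believes is a commit is only a released savepoint.
    c.execute("SAVEPOINT service_commit")
    c.execute("INSERT INTO documents VALUES ('uploaded.md')")
    c.execute("RELEASE SAVEPOINT service_commit")


conn.execute("BEGIN")  # outer transaction owned by the test fixture
service_under_test(conn)
rows_during_test = conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]
conn.execute("ROLLBACK")  # undoes the "committed" insert as well
rows_after_test = conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]
```

The row is visible while the test runs and gone afterwards, which is what keeps the database clean between tests without recreating the schema.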

View file

@ -0,0 +1,289 @@
"""Integration conftest — runs the FastAPI app in-process via ASGITransport.
Prerequisites: PostgreSQL + pgvector only.
External system boundaries are mocked:
- LLM summarization, text embedding, text chunking (external APIs)
- Redis heartbeat (external infrastructure)
- Task dispatch is swapped via DI (InlineTaskDispatcher)
"""
from __future__ import annotations

import contextlib
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import asyncpg
import httpx
import pytest
from httpx import ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool

from app.app import app
from app.config import config as app_config
from app.db import Base
from app.services.task_dispatcher import get_task_dispatcher
from tests.integration.conftest import TEST_DATABASE_URL
from tests.utils.helpers import (
    TEST_EMAIL,
    auth_headers,
    delete_document,
    get_auth_token,
    get_search_space_id,
)

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
_ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")

pytestmark = pytest.mark.integration

# ---------------------------------------------------------------------------
# Inline task dispatcher (replaces Celery via DI, not a mock)
# ---------------------------------------------------------------------------


class InlineTaskDispatcher:
    """Processes files synchronously in the calling coroutine.

    Swapped in via FastAPI dependency_overrides so the upload endpoint
    processes documents inline instead of dispatching to Celery.

    Exceptions are caught to match Celery's fire-and-forget semantics;
    the processing function already marks documents as failed internally.
    """

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
        should_summarize: bool = False,
    ) -> None:
        from app.tasks.celery_tasks.document_tasks import (
            _process_file_with_document,
        )

        with contextlib.suppress(Exception):
            await _process_file_with_document(
                document_id,
                temp_path,
                filename,
                search_space_id,
                user_id,
                should_summarize=should_summarize,
            )


app.dependency_overrides[get_task_dispatcher] = lambda: InlineTaskDispatcher()

# ---------------------------------------------------------------------------
# Database setup (ASGITransport skips the app lifespan)
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
async def _ensure_tables():
    """Create DB tables and extensions once per session."""
    engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
        await conn.run_sync(Base.metadata.create_all)
    await engine.dispose()


# ---------------------------------------------------------------------------
# Auth & search space (session-scoped, via the in-process app)
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
    """Authenticate once per session, registering the user if needed."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
    ) as c:
        return await get_auth_token(c)


@pytest.fixture(scope="session")
async def search_space_id(auth_token: str) -> int:
    """Discover the first search space belonging to the test user."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
    ) as c:
        return await get_search_space_id(c, auth_token)


@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    return auth_headers(auth_token)

# ---------------------------------------------------------------------------
# Per-test HTTP client & cleanup
# ---------------------------------------------------------------------------


@pytest.fixture
async def client() -> AsyncGenerator[httpx.AsyncClient]:
    """Per-test async HTTP client using ASGITransport (no running server)."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=180.0
    ) as c:
        yield c


@pytest.fixture
def cleanup_doc_ids() -> list[int]:
    """Accumulator for document IDs that should be deleted after the test."""
    return []


@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(search_space_id: int):
    """Delete stale documents from previous runs before the session starts."""
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        result = await conn.execute(
            "DELETE FROM documents WHERE search_space_id = $1",
            search_space_id,
        )
        deleted = int(result.split()[-1])
        if deleted:
            print(
                f"\n[purge] Deleted {deleted} stale document(s) "
                f"from search space {search_space_id}"
            )
    finally:
        await conn.close()
    yield


@pytest.fixture(autouse=True)
async def _cleanup_documents(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    cleanup_doc_ids: list[int],
):
    """Delete test documents after every test (API first, DB fallback)."""
    yield
    remaining_ids: list[int] = []
    for doc_id in cleanup_doc_ids:
        try:
            resp = await delete_document(client, headers, doc_id)
            if resp.status_code == 409:
                remaining_ids.append(doc_id)
        except Exception:
            remaining_ids.append(doc_id)
    if remaining_ids:
        conn = await asyncpg.connect(_ASYNCPG_URL)
        try:
            await conn.execute(
                "DELETE FROM documents WHERE id = ANY($1::int[])",
                remaining_ids,
            )
        finally:
            await conn.close()

# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB for setup, API for verification)
# ---------------------------------------------------------------------------


async def _get_user_page_usage(email: str) -> tuple[int, int]:
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        row = await conn.fetchrow(
            'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
            email,
        )
        assert row is not None, f"User {email!r} not found in database"
        return row["pages_used"], row["pages_limit"]
    finally:
        await conn.close()


async def _set_user_page_limits(
    email: str, *, pages_used: int, pages_limit: int
) -> None:
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        await conn.execute(
            'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
            pages_used,
            pages_limit,
            email,
        )
    finally:
        await conn.close()


@pytest.fixture
async def page_limits():
    """Manipulate the test user's page limits (direct DB for setup only).

    Automatically restores original values after each test.
    """

    class _PageLimits:
        async def set(self, *, pages_used: int, pages_limit: int) -> None:
            await _set_user_page_limits(
                TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
            )

    original = await _get_user_page_usage(TEST_EMAIL)
    yield _PageLimits()
    await _set_user_page_limits(
        TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
    )

# ---------------------------------------------------------------------------
# Mock external system boundaries
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _mock_external_apis(monkeypatch):
    """Mock LLM, embedding, and chunking; these are external API boundaries."""
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        AsyncMock(return_value="Mocked summary."),
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.embed_text",
        MagicMock(return_value=[0.1] * _EMBEDDING_DIM),
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
        MagicMock(return_value=["Test chunk content."]),
    )


@pytest.fixture(autouse=True)
def _mock_redis_heartbeat(monkeypatch):
    """Mock Redis heartbeat; Redis is an external infrastructure boundary."""
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._start_heartbeat",
        lambda notification_id: None,
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._stop_heartbeat",
        lambda notification_id: None,
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
        AsyncMock(),
    )
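
The `InlineTaskDispatcher` in this conftest wraps the processing call in `contextlib.suppress(Exception)` so that a failure never propagates to the upload endpoint, which is how a fire-and-forget queue behaves from the caller's side. The following is a minimal standalone sketch of that pattern, not code from this repository: `FakeDispatcher`, `process`, and the in-memory `statuses` dict are hypothetical stand-ins for the real dispatcher, processing function, and documents table.

```python
import asyncio
import contextlib

# Hypothetical in-memory status store standing in for the documents table.
statuses: dict[int, str] = {}


async def process(document_id: int, *, fail: bool) -> None:
    """Hypothetical processing function: records the failure itself, then re-raises."""
    try:
        if fail:
            raise RuntimeError("ETL failed")
        statuses[document_id] = "ready"
    except Exception:
        statuses[document_id] = "failed"
        raise


class FakeDispatcher:
    """Runs the work inline and swallows errors, like a fire-and-forget queue."""

    async def dispatch(self, document_id: int, *, fail: bool = False) -> None:
        with contextlib.suppress(Exception):
            await process(document_id, fail=fail)


async def main() -> dict[int, str]:
    dispatcher = FakeDispatcher()
    await dispatcher.dispatch(1)
    await dispatcher.dispatch(2, fail=True)  # does not raise to the caller
    return statuses


result = asyncio.run(main())
```

Because the error is swallowed, a test can only observe the failure through the recorded status, which is why the upload tests poll document status instead of expecting an exception.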


@ -0,0 +1,337 @@
"""
Integration tests for the document upload HTTP API.
Covers the API contract, auth, duplicate detection, and error handling.
Pipeline internals are tested in the ``indexing_pipeline`` suite.
Requires PostgreSQL + pgvector.
"""
from __future__ import annotations

import shutil
from pathlib import Path

import httpx
import pytest

from tests.utils.helpers import (
    FIXTURES_DIR,
    poll_document_status,
    upload_file,
    upload_multiple_files,
)

pytestmark = pytest.mark.integration

# ---------------------------------------------------------------------------
# Upload smoke tests (one per distinct code-path: direct-read & ETL)
# ---------------------------------------------------------------------------


class TestTxtFileUpload:
    """Upload a plain-text file (direct-read path) via the HTTP API."""

    async def test_upload_txt_returns_document_id(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["pending_files"] >= 1
        assert len(body["document_ids"]) >= 1
        cleanup_doc_ids.extend(body["document_ids"])

    async def test_txt_processing_reaches_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"


class TestPdfFileUpload:
    """Upload a PDF (ETL extraction path) via the HTTP API."""

    async def test_pdf_processing_reaches_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"

# ---------------------------------------------------------------------------
# Test D: Upload multiple files in a single request
# ---------------------------------------------------------------------------


class TestMultiFileUpload:
    """Upload several files at once and verify the API response contract."""

    async def test_multi_upload_returns_all_ids(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_multiple_files(
            client,
            headers,
            ["sample.txt", "sample.md"],
            search_space_id=search_space_id,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["pending_files"] == 2
        assert len(body["document_ids"]) == 2
        cleanup_doc_ids.extend(body["document_ids"])

# ---------------------------------------------------------------------------
# Test E: Duplicate file upload (same file uploaded twice)
# ---------------------------------------------------------------------------


class TestDuplicateFileUpload:
    """
    Uploading the exact same file a second time should be detected as a
    duplicate via ``unique_identifier_hash``.
    """

    async def test_duplicate_file_is_skipped(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp1 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp1.status_code == 200
        first_ids = resp1.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)
        await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id
        )
        resp2 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["skipped_duplicates"] >= 1
        assert len(body2["duplicate_document_ids"]) >= 1
        cleanup_doc_ids.extend(body2.get("document_ids", []))

# ---------------------------------------------------------------------------
# Test F: Duplicate content detection (different name, same content)
# ---------------------------------------------------------------------------


class TestDuplicateContentDetection:
    """
    Uploading a file with a different name but identical content should be
    detected as duplicate content via ``content_hash``.
    """

    async def test_same_content_different_name_detected(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        tmp_path: Path,
    ):
        resp1 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp1.status_code == 200
        first_ids = resp1.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)
        await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id
        )
        src = FIXTURES_DIR / "sample.txt"
        dest = tmp_path / "renamed_sample.txt"
        shutil.copy2(src, dest)
        with open(dest, "rb") as f:
            resp2 = await client.post(
                "/api/v1/documents/fileupload",
                headers=headers,
                files={"files": ("renamed_sample.txt", f)},
                data={"search_space_id": str(search_space_id)},
            )
        assert resp2.status_code == 200
        second_ids = resp2.json()["document_ids"]
        cleanup_doc_ids.extend(second_ids)
        assert second_ids, (
            "Expected at least one document id for renamed duplicate content upload"
        )
        statuses = await poll_document_status(
            client, headers, second_ids, search_space_id=search_space_id
        )
        for did in second_ids:
            assert statuses[did]["status"]["state"] == "failed"
            assert "duplicate" in statuses[did]["status"].get("reason", "").lower()

# ---------------------------------------------------------------------------
# Test G: Empty / corrupt file handling
# ---------------------------------------------------------------------------


class TestEmptyFileUpload:
    """An empty file should be processed but ultimately fail gracefully."""

    async def test_empty_pdf_fails(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "empty.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        assert doc_ids, "Expected at least one document id for empty PDF upload"
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "failed"
            assert statuses[did]["status"].get("reason"), (
                "Failed document should include a reason"
            )

# ---------------------------------------------------------------------------
# Test H: Upload without authentication
# ---------------------------------------------------------------------------


class TestUnauthenticatedUpload:
    """Requests without a valid JWT should be rejected."""

    async def test_upload_without_auth_returns_401(
        self,
        client: httpx.AsyncClient,
        search_space_id: int,
    ):
        file_path = FIXTURES_DIR / "sample.txt"
        with open(file_path, "rb") as f:
            resp = await client.post(
                "/api/v1/documents/fileupload",
                files={"files": ("sample.txt", f)},
                data={"search_space_id": str(search_space_id)},
            )
        assert resp.status_code == 401


# ---------------------------------------------------------------------------
# Test I: Upload with no files attached
# ---------------------------------------------------------------------------


class TestNoFilesUpload:
    """Submitting the form with zero files should return a validation error."""

    async def test_no_files_returns_error(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code in {400, 422}

# ---------------------------------------------------------------------------
# Test K: Searchability after upload
# ---------------------------------------------------------------------------


class TestDocumentSearchability:
    """After upload reaches ready, the document must appear in the title search."""

    async def test_uploaded_document_appears_in_search(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        search_resp = await client.get(
            "/api/v1/documents/search",
            headers=headers,
            params={"title": "sample", "search_space_id": search_space_id},
        )
        assert search_resp.status_code == 200
        result_ids = [d["id"] for d in search_resp.json()["items"]]
        assert doc_ids[0] in result_ids, (
            f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
        )
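
Most tests in this file upload a file and then wait for processing to reach a terminal state via `poll_document_status`, whose implementation lives in `tests/utils/helpers` and is not shown in this diff. As a rough illustration of the poll-until-terminal idea only, here is a sketch under stated assumptions: `poll_until_terminal`, `fetch_states`, and `TERMINAL_STATES` are hypothetical names, and the real helper may differ in signature and behavior.

```python
import asyncio
import time
from collections.abc import Awaitable, Callable

# Assumed terminal states, matching the "ready" / "failed" values the tests assert on.
TERMINAL_STATES = {"ready", "failed"}


async def poll_until_terminal(
    fetch_states: Callable[[], Awaitable[dict[int, str]]],
    *,
    timeout: float = 5.0,
    interval: float = 0.01,
) -> dict[int, str]:
    """Call fetch_states until every state is terminal or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        states = await fetch_states()
        if all(state in TERMINAL_STATES for state in states.values()):
            return states
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Documents still processing: {states}")
        await asyncio.sleep(interval)


async def demo() -> dict[int, str]:
    # Hypothetical backend: document 7 becomes ready on the third poll.
    calls = 0

    async def fetch_states() -> dict[int, str]:
        nonlocal calls
        calls += 1
        return {7: "ready" if calls >= 3 else "processing"}

    return await poll_until_terminal(fetch_states)


final_states = asyncio.run(demo())
```

Returning on `failed` as well as `ready` matters here: the duplicate-content and empty-PDF tests expect the document to end in `failed`, so a poller that waited only for `ready` would time out on them.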


@ -0,0 +1,332 @@
"""
Integration tests for page-limit enforcement during document upload.
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database (setup only) and then exercise the upload
pipeline to verify that:
- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload (verified via API).
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
so no additional processing time is introduced.
Prerequisites:
- PostgreSQL + pgvector
"""
from __future__ import annotations

import httpx
import pytest

from tests.utils.helpers import (
    get_notifications,
    poll_document_status,
    upload_file,
)

pytestmark = pytest.mark.integration

# ---------------------------------------------------------------------------
# Helper: read pages_used through the public API
# ---------------------------------------------------------------------------


async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int:
    """Fetch the current user's pages_used via the /users/me API."""
    resp = await client.get("/users/me", headers=headers)
    assert resp.status_code == 200, (
        f"GET /users/me failed ({resp.status_code}): {resp.text}"
    )
    return resp.json()["pages_used"]

# ---------------------------------------------------------------------------
# Test A: Successful upload increments pages_used
# ---------------------------------------------------------------------------


class TestPageUsageIncrementsOnSuccess:
    """After a successful PDF upload the user's ``pages_used`` must grow."""

    async def test_pages_used_increases_after_pdf_upload(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=0, pages_limit=1000)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"
        used = await _get_pages_used(client, headers)
        assert used > 0, "pages_used should have increased after successful processing"

# ---------------------------------------------------------------------------
# Test B: Upload rejected when page limit is fully exhausted
# ---------------------------------------------------------------------------


class TestUploadRejectedWhenLimitExhausted:
    """
    When ``pages_used == pages_limit`` (zero remaining) the document
    should reach ``failed`` status with a page-limit reason.
    """

    async def test_pdf_fails_when_no_pages_remaining(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=100, pages_limit=100)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "failed"
            reason = statuses[did]["status"].get("reason", "").lower()
            assert "page limit" in reason, (
                f"Expected 'page limit' in failure reason, got: {reason!r}"
            )

    async def test_pages_used_unchanged_after_limit_rejection(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=50, pages_limit=50)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        used = await _get_pages_used(client, headers)
        assert used == 50, (
            f"pages_used should remain 50 after rejected upload, got {used}"
        )

# ---------------------------------------------------------------------------
# Test C: Page-limit notification is created on rejection
# ---------------------------------------------------------------------------


class TestPageLimitNotification:
    """A ``page_limit_exceeded`` notification must be created when upload
    is rejected due to the limit."""

    async def test_page_limit_exceeded_notification_created(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=100, pages_limit=100)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        notifications = await get_notifications(
            client,
            headers,
            type_filter="page_limit_exceeded",
            search_space_id=search_space_id,
        )
        assert len(notifications) >= 1, (
            "Expected at least one page_limit_exceeded notification"
        )
        latest = notifications[0]
        assert (
            "page limit" in latest["title"].lower()
            or "page limit" in latest["message"].lower()
        ), (
            f"Notification should mention page limit: title={latest['title']!r}, "
            f"message={latest['message']!r}"
        )

# ---------------------------------------------------------------------------
# Test D: Successful upload creates a completed document_processing notification
# ---------------------------------------------------------------------------


class TestDocumentProcessingNotification:
    """A ``document_processing`` notification with ``completed`` status must
    exist after a successful upload."""

    async def test_processing_completed_notification_exists(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=0, pages_limit=1000)
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        notifications = await get_notifications(
            client,
            headers,
            type_filter="document_processing",
            search_space_id=search_space_id,
        )
        completed = [
            n
            for n in notifications
            if n.get("metadata", {}).get("processing_stage") == "completed"
        ]
        assert len(completed) >= 1, (
            "Expected at least one document_processing notification with 'completed' stage"
        )

# ---------------------------------------------------------------------------
# Test E: pages_used unchanged when a document fails for non-limit reasons
# ---------------------------------------------------------------------------


class TestPagesUnchangedOnProcessingFailure:
    """If a document fails during ETL (e.g. empty/corrupt file) rather than
    a page-limit rejection, ``pages_used`` should remain unchanged."""

    async def test_pages_used_stable_on_etl_failure(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=10, pages_limit=1000)
        resp = await upload_file(
            client, headers, "empty.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        if doc_ids:
            statuses = await poll_document_status(
                client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
            )
            for did in doc_ids:
                assert statuses[did]["status"]["state"] == "failed"
        used = await _get_pages_used(client, headers)
        assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"

# ---------------------------------------------------------------------------
# Test F: Second upload rejected after first consumes remaining quota
# ---------------------------------------------------------------------------


class TestSecondUploadExceedsLimit:
    """Upload one PDF successfully, consuming the quota, then verify a
    second upload is rejected."""

    async def test_second_upload_rejected_after_quota_consumed(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        await page_limits.set(pages_used=0, pages_limit=1)
        resp1 = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp1.status_code == 200
        first_ids = resp1.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)
        statuses1 = await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in first_ids:
            assert statuses1[did]["status"]["state"] == "ready"
        resp2 = await upload_file(
            client,
            headers,
            "sample.pdf",
            search_space_id=search_space_id,
            filename_override="sample_copy.pdf",
        )
        assert resp2.status_code == 200
        second_ids = resp2.json()["document_ids"]
        cleanup_doc_ids.extend(second_ids)
        statuses2 = await poll_document_status(
            client, headers, second_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in second_ids:
            assert statuses2[did]["status"]["state"] == "failed"
            reason = statuses2[did]["status"].get("reason", "").lower()
            assert "page limit" in reason, (
                f"Expected 'page limit' in failure reason, got: {reason!r}"
            )
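
The page-limit tests in this file pin down an observable contract: an upload is rejected when no pages remain, an upload that exactly fills the quota succeeds, and `pages_used` does not move on rejection. One rule consistent with those tests is "accept only if `pages_used + estimated_pages <= pages_limit`". The sketch below illustrates that rule; `PageQuota` and its methods are hypothetical, the one-page estimate for `sample.pdf` is an assumption, and the server's actual check is not shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class PageQuota:
    """Hypothetical mirror of the user's pages_used / pages_limit columns."""

    pages_used: int
    pages_limit: int

    def can_accept(self, estimated_pages: int) -> bool:
        # Accept only if the document fits into the remaining quota.
        return self.pages_used + estimated_pages <= self.pages_limit

    def consume(self, estimated_pages: int) -> None:
        # Reject before touching the counter, so a rejection leaves pages_used unchanged.
        if not self.can_accept(estimated_pages):
            raise PermissionError("Page limit exceeded")
        self.pages_used += estimated_pages


# Mirrors the "second upload" scenario: limit of 1, assumed one-page document.
quota = PageQuota(pages_used=0, pages_limit=1)
quota.consume(1)  # first upload fits exactly
second_allowed = quota.can_accept(1)  # quota is now exhausted
```

Checking before incrementing is the design point: it is what lets the tests assert that `pages_used` stays at 50 after a rejected upload.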


@ -0,0 +1,145 @@
"""
Integration tests for backend file upload limit enforcement.
These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
- Max per-file size (50 MB)
- Max total upload size (200 MB)
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.
Prerequisites:
- PostgreSQL + pgvector
"""
from __future__ import annotations

import io

import httpx
import pytest

pytestmark = pytest.mark.integration

# ---------------------------------------------------------------------------
# Test A: File count limit
# ---------------------------------------------------------------------------


class TestFileCountLimit:
    """Uploading more than 10 files in a single request should be rejected."""

    async def test_11_files_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        files = [
            ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            for i in range(11)
        ]
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 413
        assert "too many files" in resp.json()["detail"].lower()

    async def test_10_files_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        files = [
            ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            for i in range(10)
        ]
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))

# ---------------------------------------------------------------------------
# Test B: Per-file size limit
# ---------------------------------------------------------------------------


class TestPerFileSizeLimit:
    """A single file exceeding 50 MB should be rejected."""

    async def test_oversized_file_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        oversized = io.BytesIO(b"\x00" * (50 * 1024 * 1024 + 1))
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("big.pdf", oversized, "application/pdf"))],
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 413
        assert "per-file limit" in resp.json()["detail"].lower()

    async def test_file_at_limit_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))

# ---------------------------------------------------------------------------
# Test C: Total upload size limit
# ---------------------------------------------------------------------------


class TestTotalSizeLimit:
    """Multiple files whose combined size exceeds 200 MB should be rejected."""

    async def test_total_size_over_200mb_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        chunk_size = 45 * 1024 * 1024  # 45 MB each
        files = [
            (
                "files",
                (f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
            )
            for i in range(5)  # 5 x 45 MB = 225 MB > 200 MB
        ]
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 413
        assert "total upload size" in resp.json()["detail"].lower()
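
The three limits exercised in this file (10 files, 50 MB per file, 200 MB total) can be expressed as a single pure check over the file sizes. The sketch below is one possible shape for such a check, not the endpoint's actual code: `check_upload_limits` and the constant names are hypothetical, and the message wording is assumed, chosen only so that it contains the substrings the tests look for.

```python
from __future__ import annotations

# Limits as stated in the module docstring of the test file.
MAX_FILES = 10
MAX_FILE_BYTES = 50 * 1024 * 1024
MAX_TOTAL_BYTES = 200 * 1024 * 1024


def check_upload_limits(sizes: list[int]) -> str | None:
    """Return an error message for the first violated limit, or None if all pass."""
    if len(sizes) > MAX_FILES:
        return f"Too many files: {len(sizes)} exceeds the maximum of {MAX_FILES}"
    for size in sizes:
        # Strictly greater: a file of exactly 50 MB is still accepted.
        if size > MAX_FILE_BYTES:
            return f"File of {size} bytes exceeds the per-file limit"
    if sum(sizes) > MAX_TOTAL_BYTES:
        return f"Total upload size of {sum(sizes)} bytes exceeds the limit"
    return None


# Five 45 MB files pass the count and per-file checks but not the total.
five_chunks = check_upload_limits([45 * 1024 * 1024] * 5)
```

A caller would presumably map a non-`None` result to an HTTP 413 response, which is the status the tests in this file assert on.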


@ -0,0 +1,99 @@
import pytest
from sqlalchemy import select

from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file

pytestmark = pytest.mark.integration


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
    """Document status is READY after successful indexing."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )
    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()
    assert DocumentStatus.is_state(document.status, DocumentStatus.READY)


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
    """Document content is set to the LLM-generated summary."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )
    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()
    assert document.content == "Mocked summary."


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
    """Chunks derived from the source markdown are persisted in the DB."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()

    chunks_result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document.id)
    )
    chunks = chunks_result.scalars().all()

    assert len(chunks) == 1
    assert chunks[0].content == "Test chunk content."


@pytest.mark.usefixtures(
    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
    """RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
    with pytest.raises(RuntimeError):
        await index_uploaded_file(
            markdown_content="## Hello\n\nSome content.",
            filename="test.pdf",
            etl_service="UNSTRUCTURED",
            search_space_id=db_search_space.id,
            user_id=str(db_user.id),
            session=db_session,
            llm=mocker.Mock(),
        )
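
Note: the `patched_summarize`, `patched_summarize_raises`, `patched_embed_text`, and `patched_chunk_text` fixtures used above are defined elsewhere in the PR and are not shown in this section. As a minimal sketch of the stubbing technique they presumably rely on, the standard-library example below replaces an async summarizer with a stub that either returns a canned summary or raises. `Pipeline`, `summarize`, `index`, and `run_with_stub` are hypothetical stand-ins for illustration, not the project's real code.

```python
import asyncio
from unittest.mock import AsyncMock, patch


class Pipeline:
    """Hypothetical stand-in for the code under test (not the PR's service)."""

    async def summarize(self, text: str) -> str:
        # The real implementation would call an LLM; tests must never reach it.
        raise NotImplementedError("would call a real LLM")

    async def index(self, text: str) -> dict:
        # Report "ready" with the summary on success, "failed" on any error.
        try:
            summary = await self.summarize(text)
        except Exception as exc:
            return {"status": "failed", "content": None, "reason": str(exc)}
        return {"status": "ready", "content": summary}


async def run_with_stub(stub: AsyncMock) -> dict:
    # patch.object swaps the method only for the duration of the with-block,
    # which is the same scoping a pytest fixture would give each test.
    with patch.object(Pipeline, "summarize", stub):
        return await Pipeline().index("## Hello\n\nSome content.")


# Success path: the stub returns a canned summary.
ok = asyncio.run(run_with_stub(AsyncMock(return_value="Mocked summary.")))

# Failure path: the stub raises, as a "patched_summarize_raises" style fixture would.
bad = asyncio.run(run_with_stub(AsyncMock(side_effect=RuntimeError("LLM down"))))
```

Stubbing only the external boundary (the LLM call) keeps the rest of the path real, which is why the tests above can still assert on rows read back from the database.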

@ -0,0 +1,341 @@
import pytest
from sqlalchemy import select

from app.config import config as app_config
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension

pytestmark = pytest.mark.integration


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_sets_status_ready(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Document status is READY after successful indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_summary_when_should_summarize_true(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Document content is set to the LLM-generated summary when should_summarize=True."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert reloaded.content == "Mocked summary."


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_source_markdown_when_should_summarize_false(
    db_session,
    db_search_space,
    make_connector_document,
):
    """Document content is set to source_markdown verbatim when should_summarize=False."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=False,
        source_markdown="## Raw content",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=None)

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert reloaded.content == "## Raw content"


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_chunks_written_to_db(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Chunks derived from source_markdown are persisted in the DB."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    chunks = result.scalars().all()

    assert len(chunks) == 1
    assert chunks[0].content == "Test chunk content."


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_embedding_written_to_db(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Document embedding vector is persisted in the DB after indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert reloaded.embedding is not None
    assert len(reloaded.embedding) == _EMBEDDING_DIM


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_updated_at_advances_after_indexing(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """updated_at timestamp is later after indexing than it was at prepare time."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    updated_at_pending = result.scalars().first().updated_at

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    updated_at_ready = result.scalars().first().updated_at

    assert updated_at_ready > updated_at_pending


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_no_llm_falls_back_to_source_markdown(
    db_session,
    db_search_space,
    make_connector_document,
):
    """When llm=None and no fallback_summary, content falls back to source_markdown."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=True,
        source_markdown="## Fallback content",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=None)

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
    assert reloaded.content == "## Fallback content"


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_fallback_summary_used_when_llm_unavailable(
    db_session,
    db_search_space,
    make_connector_document,
):
    """fallback_summary is used as content when llm=None and should_summarize=True."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=True,
        source_markdown="## Full raw content",
        fallback_summary="Short pre-built summary.",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document_id = prepared[0].id

    await service.index(prepared[0], connector_doc, llm=None)

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
    assert reloaded.content == "Short pre-built summary."


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_reindex_replaces_old_chunks(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Re-indexing a document replaces its old chunks rather than appending."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v1",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    updated_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v2",
    )
    re_prepared = await service.prepare_for_indexing([updated_doc])
    await service.index(re_prepared[0], updated_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    chunks = result.scalars().all()

    assert len(chunks) == 1


@pytest.mark.usefixtures(
    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_sets_status_failed(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Document status is FAILED when the LLM raises during indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)


@pytest.mark.usefixtures(
    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_leaves_no_partial_data(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """A failed indexing attempt leaves no partial embedding or chunks in the DB."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert reloaded.embedding is None
    assert reloaded.content == "Pending..."

    chunks_result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    assert chunks_result.scalars().all() == []
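
Note: the content-related tests in this file pin down a precedence order for what ends up in `Document.content`. With `should_summarize=False` the source markdown is stored verbatim. Otherwise the LLM summary is used when an LLM is available, then `fallback_summary`, and finally the source markdown. A minimal sketch of that rule as a pure function follows, assuming a hypothetical `resolve_content` helper; the real `IndexingPipelineService.index` is not shown in this section and may be structured differently. Here `llm_summary=None` models the `llm=None` case.

```python
from typing import Optional


def resolve_content(
    source_markdown: str,
    should_summarize: bool,
    llm_summary: Optional[str] = None,
    fallback_summary: Optional[str] = None,
) -> str:
    """Hypothetical helper mirroring the precedence the tests assert."""
    if not should_summarize:
        # Summarization disabled: store the source verbatim.
        return source_markdown
    if llm_summary is not None:
        # An LLM was available and produced a summary.
        return llm_summary
    if fallback_summary is not None:
        # No LLM, but the connector supplied a pre-built summary.
        return fallback_summary
    # Last resort: index the raw markdown rather than failing.
    return source_markdown


verbatim = resolve_content("## Raw content", should_summarize=False)
summarized = resolve_content(
    "## Hello", should_summarize=True, llm_summary="Mocked summary."
)
prebuilt = resolve_content(
    "## Full raw content",
    should_summarize=True,
    fallback_summary="Short pre-built summary.",
)
raw_fallback = resolve_content("## Fallback content", should_summarize=True)
```

Keeping the decision in one place like this would let the precedence be unit-tested without a database, while the integration tests above confirm the persisted result.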

Some files were not shown because too many files have changed in this diff.